コード例 #1
0
        /// <summary>
        /// Creates an all-black PNG sized like the wrapped sensor's observation, used for
        /// initializing the buffer for stacking.
        /// </summary>
        /// <returns>The PNG-encoded bytes of the blank image.</returns>
        internal byte[] CreateEmptyPNG()
        {
            // The first two entries of the wrapped observation shape are read as height and width.
            var shape     = m_WrappedSensor.GetObservationShape();
            int height    = shape[0];
            int width     = shape[1];
            var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);

            // Explicitly clear every pixel so the encoded image does not depend on
            // whatever contents a freshly constructed texture happens to hold.
            Color32[] resetColorArray = texture2D.GetPixels32();
            Color32   black           = new Color32(0, 0, 0, 0);

            for (int i = 0; i < resetColorArray.Length; i++)
            {
                resetColorArray[i] = black;
            }
            texture2D.SetPixels32(resetColorArray);
            texture2D.Apply();

            return texture2D.EncodeToPNG();
        }
コード例 #2
0
        /// <summary>
        /// Initializes a sensor that wraps another sensor and allocates one observation
        /// buffer per stacked observation. The reported shape is the wrapped shape with
        /// its first dimension multiplied by <paramref name="numStackedObservations"/>.
        /// </summary>
        /// <param name="wrapped">The wrapped sensor</param>
        /// <param name="numStackedObservations">Number of stacked observations to keep; must be at least 1.</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="wrapped"/> is null.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// <paramref name="numStackedObservations"/> is less than 1.
        /// </exception>
        public StackingSensor(ISensor wrapped, int numStackedObservations)
        {
            // Fail fast with a descriptive error instead of a later NullReferenceException
            // or a failed array allocation.
            if (wrapped == null)
            {
                throw new System.ArgumentNullException(nameof(wrapped));
            }
            if (numStackedObservations < 1)
            {
                throw new System.ArgumentOutOfRangeException(
                    nameof(numStackedObservations),
                    numStackedObservations,
                    "StackingSensor needs at least one stacked observation.");
            }

            m_WrappedSensor          = wrapped;
            m_NumStackedObservations = numStackedObservations;

            m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";

            var shape = wrapped.GetObservationShape();

            m_UnstackedObservationSize = wrapped.ObservationSize();

            // Work on a private copy so the array returned by the wrapped sensor is never modified.
            m_Shape = (int[])shape.Clone();

            // TODO support arbitrary stacking dimension
            m_Shape[0] *= numStackedObservations;

            // One buffer per stacked observation, each the size of a single wrapped observation.
            m_StackedObservations = new float[numStackedObservations][];
            for (var i = 0; i < numStackedObservations; i++)
            {
                m_StackedObservations[i] = new float[m_UnstackedObservationSize];
            }
        }
コード例 #3
0
ファイル: StackingSensor.cs プロジェクト: wszhs/ml-agents
        /// <summary>
        /// Initializes the sensor. Only uncompressed wrapped sensors with a 1-D observation
        /// shape are accepted; the reported shape is the wrapped length multiplied by
        /// <paramref name="numStackedObservations"/>.
        /// </summary>
        /// <param name="wrapped">The wrapped sensor.</param>
        /// <param name="numStackedObservations">Number of stacked observations to keep. Must be at least 1.</param>
        /// <exception cref="UnityAgentsException">
        /// Thrown when <paramref name="wrapped"/> is null, when
        /// <paramref name="numStackedObservations"/> is less than 1, when the wrapped sensor
        /// uses compression, or when its observation shape is not 1-D.
        /// </exception>
        public StackingSensor(ISensor wrapped, int numStackedObservations)
        {
            // Validate the arguments up front so failures carry a descriptive message.
            if (wrapped == null)
            {
                throw new UnityAgentsException("StackingSensor requires a non-null wrapped sensor.");
            }
            if (numStackedObservations < 1)
            {
                throw new UnityAgentsException(
                    $"StackingSensor requires numStackedObservations >= 1, but got {numStackedObservations}.");
            }

            m_WrappedSensor          = wrapped;
            m_NumStackedObservations = numStackedObservations;

            m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";

            if (wrapped.GetCompressionType() != SensorCompressionType.None)
            {
                throw new UnityAgentsException("StackingSensor doesn't support compressed observations.");
            }

            var shape = wrapped.GetObservationShape();

            if (shape.Length != 1)
            {
                throw new UnityAgentsException("Only 1-D observations are supported by StackingSensor");
            }

            m_UnstackedObservationSize = wrapped.ObservationSize();

            // Work on a private copy so the array returned by the wrapped sensor is never modified.
            m_Shape = (int[])shape.Clone();

            // TODO support arbitrary stacking dimension
            m_Shape[0] *= numStackedObservations;

            // One buffer per stacked observation, each the size of a single wrapped observation.
            m_StackedObservations = new float[numStackedObservations][];
            for (var i = 0; i < numStackedObservations; i++)
            {
                m_StackedObservations[i] = new float[m_UnstackedObservationSize];
            }
        }
コード例 #4
0
        /// <summary>
        /// Writes the sensor's observation into a scratch buffer and asserts that every
        /// element equals the corresponding entry of <paramref name="expected"/>.
        /// </summary>
        /// <param name="sensor">The sensor whose output is checked.</param>
        /// <param name="expected">The expected observation values; may be empty.</param>
        public static void CompareObservation(ISensor sensor, float[] expected)
        {
            var         numExpected = expected.Length;
            const float fill        = -1337f;
            var         output      = new float[numExpected];

            // Pre-fill with a sentinel so untouched elements are distinguishable from written ones.
            for (var i = 0; i < numExpected; i++)
            {
                output[i] = fill;
            }

            // The sentinel checks read output[0], so they only apply to a non-empty buffer.
            if (numExpected > 0)
            {
                Assert.AreEqual(fill, output[0]);
            }

            ObservationWriter writer = new ObservationWriter();

            writer.SetTarget(output, sensor.GetObservationShape(), 0);

            // Make sure ObservationWriter didn't touch anything
            if (numExpected > 0)
            {
                Assert.AreEqual(fill, output[0]);
            }

            sensor.Write(writer);
            for (var i = 0; i < numExpected; i++)
            {
                Assert.AreEqual(expected[i], output[i]);
            }
        }
コード例 #5
0
        /// <summary>
        /// Generate an ObservationProto for the sensor using the provided WriteAdapter.
        /// This is equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// </summary>
        /// <param name="sensor">The sensor to read the observation from.</param>
        /// <param name="writeAdapter">Adapter used to write uncompressed float data into the proto.</param>
        /// <returns>A proto holding either float data or compressed data, plus the observation shape.</returns>
        /// <exception cref="UnityAgentsException">
        /// Thrown when the sensor reports a compression type other than None but returns
        /// null from GetCompressedObservation().
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, WriteAdapter writeAdapter)
        {
            // Query the sensor once so the whole proto is built from a single snapshot
            // of its shape and compression type.
            var shape           = sensor.GetObservationShape();
            var compressionType = sensor.GetCompressionType();
            ObservationProto observationProto;

            if (compressionType == SensorCompressionType.None)
            {
                var numFloats      = sensor.ObservationSize();
                var floatDataProto = new ObservationProto.Types.FloatData();
                // Resize the float array
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                for (var i = 0; i < numFloats; i++)
                {
                    floatDataProto.Data.Add(0.0f);
                }

                writeAdapter.SetTarget(floatDataProto.Data, shape, 0);
                sensor.Write(writeAdapter);

                observationProto = new ObservationProto
                {
                    FloatData       = floatDataProto,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }
            else
            {
                var compressedObs = sensor.GetCompressedObservation();
                if (compressedObs == null)
                {
                    throw new UnityAgentsException(
                              $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                              "You must return a byte[]. If you don't want to use compressed observations, " +
                              "return SensorCompressionType.None from GetCompressionType()."
                              );
                }

                observationProto = new ObservationProto
                {
                    CompressedData  = ByteString.CopyFrom(compressedObs),
                    CompressionType = (CompressionTypeProto)compressionType,
                };
            }

            observationProto.Shape.AddRange(shape);
            return observationProto;
        }
コード例 #6
0
        /// <summary>
        /// Builds a PNG in which every pixel is Color32(0, 0, 0, 0). The image dimensions
        /// are taken from the first two entries of the wrapped sensor's observation shape.
        /// The result is used to initialize the buffer for stacking.
        /// </summary>
        /// <returns>The PNG-encoded bytes of the blank image.</returns>
        internal byte[] CreateEmptyPNG()
        {
            int rows    = m_WrappedSensor.GetObservationShape()[0];
            int columns = m_WrappedSensor.GetObservationShape()[1];
            var blank   = new Texture2D(columns, rows, TextureFormat.RGB24, false);

            // Overwrite the texture's pixel array in place, one entry at a time.
            var pixels = blank.GetPixels32();
            var cursor = 0;
            while (cursor < pixels.Length)
            {
                pixels[cursor] = new Color32(0, 0, 0, 0);
                cursor++;
            }

            blank.SetPixels32(pixels);
            blank.Apply();

            var encoded = blank.EncodeToPNG();
            return encoded;
        }
コード例 #7
0
        /// <summary>
        /// Writes the sensor's observation into a sentinel-filled scratch buffer and compares
        /// it element by element against a 3-D array of expected values.
        /// Intended for simplifying unit tests rather than for production code.
        /// </summary>
        /// <param name="sensor">The sensor whose output is checked.</param>
        /// <param name="expected">Expected values, indexed as [height, width, channel].</param>
        /// <param name="errorMessage">Describes the first failure found, or null on success.</param>
        /// <returns>True when every element matches; false otherwise.</returns>
        public static bool CompareObservation(ISensor sensor, float[,,] expected, out string errorMessage)
        {
            const float fill = -1337f;

            var tensorShape = new TensorShape(0, expected.GetLength(0), expected.GetLength(1), expected.GetLength(2));
            var numExpected = tensorShape.height * tensorShape.width * tensorShape.channels;

            // Sentinel-fill the buffer so unwritten elements stand out.
            var output = new float[numExpected];
            for (var slot = 0; slot < numExpected; slot++)
            {
                output[slot] = fill;
            }

            if (numExpected > 0 && fill != output[0])
            {
                errorMessage = "Error setting output buffer.";
                return false;
            }

            var writer = new ObservationWriter();
            writer.SetTarget(output, sensor.GetObservationShape(), 0);

            // Setting the target alone must leave the buffer untouched.
            if (numExpected > 0 && fill != output[0])
            {
                errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have.";
                return false;
            }

            sensor.Write(writer);

            for (var h = 0; h < tensorShape.height; h++)
            {
                for (var w = 0; w < tensorShape.width; w++)
                {
                    for (var c = 0; c < tensorShape.channels; c++)
                    {
                        var wanted = expected[h, w, c];
                        var actual = output[tensorShape.Index(0, h, w, c)];
                        if (wanted == actual)
                        {
                            continue;
                        }

                        errorMessage = $"Expected and actual differed in position [{h}, {w}, {c}]. " +
                                       $"Expected: {wanted}  Actual: {actual} ";
                        return false;
                    }
                }
            }

            errorMessage = null;
            return true;
        }
コード例 #8
0
        /// <summary>
        /// Computes how many elements the ISensor's observation contains, i.e. the product
        /// of all entries of its observation shape (1 for an empty shape).
        /// </summary>
        /// <param name="sensor">The sensor whose observation shape is multiplied out.</param>
        /// <returns>The product of the shape's dimensions.</returns>
        public static int ObservationSize(this ISensor sensor)
        {
            var dims    = sensor.GetObservationShape();
            var product = 1;

            for (var axis = 0; axis < dims.Length; axis++)
            {
                product *= dims[axis];
            }

            return product;
        }
コード例 #9
0
        /// <summary>
        /// Returns the number of elements in the ISensor's observation: the product of
        /// every entry of its observation shape (1 for an empty shape).
        /// </summary>
        /// <param name="sensor">The sensor whose observation shape is multiplied out.</param>
        /// <returns>The product of the shape's dimensions.</returns>
        public static int ObservationSize(this ISensor sensor)
        {
            var total = 1;

            foreach (var extent in sensor.GetObservationShape())
            {
                total *= extent;
            }

            return total;
        }
コード例 #10
0
        /// <summary>
        /// Generate an ObservationProto for the sensor using the provided WriteAdapter.
        /// This is equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// </summary>
        /// <param name="sensor">The sensor to read the observation from.</param>
        /// <param name="writeAdapter">Adapter used to write uncompressed float data into the proto.</param>
        /// <returns>A proto holding either float data or compressed data, plus the observation shape.</returns>
        /// <exception cref="System.InvalidOperationException">
        /// Thrown when the sensor reports a compression type other than None but returns
        /// null from GetCompressedObservation().
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, WriteAdapter writeAdapter)
        {
            // Query the sensor once so the whole proto is built from a single snapshot
            // of its shape and compression type.
            var shape           = sensor.GetObservationShape();
            var compressionType = sensor.GetCompressionType();
            ObservationProto observationProto;

            if (compressionType == SensorCompressionType.None)
            {
                var numFloats      = sensor.ObservationSize();
                var floatDataProto = new ObservationProto.Types.FloatData();
                // Resize the float array
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                for (var i = 0; i < numFloats; i++)
                {
                    floatDataProto.Data.Add(0.0f);
                }

                writeAdapter.SetTarget(floatDataProto.Data, shape, 0);
                sensor.Write(writeAdapter);

                observationProto = new ObservationProto
                {
                    FloatData       = floatDataProto,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }
            else
            {
                // Report a missing payload explicitly, naming the offending sensor,
                // rather than letting the copy below fail on a null array.
                var compressedObs = sensor.GetCompressedObservation();
                if (compressedObs == null)
                {
                    throw new System.InvalidOperationException(
                              $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                              "You must return a byte[]. If you don't want to use compressed observations, " +
                              "return SensorCompressionType.None from GetCompressionType()."
                              );
                }

                observationProto = new ObservationProto
                {
                    CompressedData  = ByteString.CopyFrom(compressedObs),
                    CompressionType = (CompressionTypeProto)compressionType,
                };
            }

            observationProto.Shape.AddRange(shape);
            return observationProto;
        }
コード例 #11
0
        /// <summary>
        /// Construct the stacked CompressedChannelMapping: the wrapped sensor's mapping,
        /// padded to a multiple of 3 and repeated once per stacked observation with a
        /// per-copy channel offset.
        /// </summary>
        /// <param name="wrappedSenesor">
        /// The sensor whose channel count (third shape entry) and channel mapping are used.
        /// </param>
        /// <returns>
        /// An array of length (padded mapping length) * (number of stacked observations);
        /// padding entries and negative source entries are -1.
        /// </returns>
        internal int[] ConstructStackedCompressedChannelMapping(ISensor wrappedSenesor)
        {
            // Get CompressedChannelMapping of the wrapped sensor. If the
            // wrapped sensor doesn't have one, use default mapping.
            // Default mapping: {0, 0, 0} for grayscale, identity mapping {0, 1, ..., n-1} otherwise.
            int[] wrappedMapping    = null;
            int   wrappedNumChannel = wrappedSenesor.GetObservationShape()[2];

            // Read the mapping from the same sensor the channel count came from, so the
            // two can never disagree.
            var sparseChannelSensor = wrappedSenesor as ISparseChannelSensor;

            if (sparseChannelSensor != null)
            {
                wrappedMapping = sparseChannelSensor.GetCompressedChannelMapping();
            }
            if (wrappedMapping == null)
            {
                if (wrappedNumChannel == 1)
                {
                    wrappedMapping = new int[] { 0, 0, 0 };
                }
                else
                {
                    wrappedMapping = Enumerable.Range(0, wrappedNumChannel).ToArray();
                }
            }

            // Construct stacked mapping using the mapping of wrapped sensor.
            // First pad the wrapped mapping to multiple of 3, then repeat
            // and add offset to each copy to form the stacked mapping.
            int paddedMapLength    = (wrappedMapping.Length + 2) / 3 * 3;
            var compressionMapping = new int[paddedMapLength * m_NumStackedObservations];

            for (var i = 0; i < m_NumStackedObservations; i++)
            {
                var offset = wrappedNumChannel * i;
                for (var j = 0; j < paddedMapLength; j++)
                {
                    // Real, non-negative entries are shifted by this copy's channel offset;
                    // padding slots and negative entries become -1.
                    var isMappedChannel = j < wrappedMapping.Length && wrappedMapping[j] >= 0;
                    compressionMapping[j + paddedMapLength * i] = isMappedChannel ? wrappedMapping[j] + offset : -1;
                }
            }
            return compressionMapping;
        }
コード例 #12
0
        /// <summary>
        /// Writes the sensor's observation into a sentinel-filled scratch buffer and compares
        /// it element by element against a flat array of expected values.
        /// Intended for simplifying unit tests rather than for production code.
        /// </summary>
        /// <param name="sensor">The sensor whose output is checked.</param>
        /// <param name="expected">The expected observation values.</param>
        /// <param name="errorMessage">Describes the first failure found, or null on success.</param>
        /// <returns>True when every element matches; false otherwise.</returns>
        public static bool CompareObservation(ISensor sensor, float[] expected, out string errorMessage)
        {
            const float fill = -1337f;

            var numExpected = expected.Length;

            // Sentinel-fill the buffer so unwritten elements stand out.
            var output = new float[numExpected];
            for (var slot = 0; slot < numExpected; slot++)
            {
                output[slot] = fill;
            }

            if (numExpected > 0 && fill != output[0])
            {
                errorMessage = "Error setting output buffer.";
                return false;
            }

            var writer = new ObservationWriter();
            writer.SetTarget(output, sensor.GetObservationShape(), 0);

            // Setting the target alone must leave the buffer untouched.
            if (numExpected > 0 && fill != output[0])
            {
                errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have.";
                return false;
            }

            sensor.Write(writer);

            for (var position = 0; position < output.Length; position++)
            {
                var wanted = expected[position];
                var actual = output[position];
                if (wanted == actual)
                {
                    continue;
                }

                errorMessage = $"Expected and actual differed in position {position}. Expected: {wanted}  Actual: {actual} ";
                return false;
            }

            errorMessage = null;
            return true;
        }
コード例 #13
0
        /// <summary>
        /// Compares the first two entries of the sensor's observation shape with the
        /// Channels and Width of the rank 2 observation input placeholder.
        /// </summary>
        /// <param name="tensorProxy">The tensor that is expected by the model</param>
        /// <param name="sensor">The sensor that produces the observation.</param>
        /// <returns>
        /// Null when both dimensions agree; otherwise a string describing the mismatch.
        /// </returns>
        static string CheckRankTwoObsShape(
            TensorProxy tensorProxy, ISensor sensor)
        {
            var sensorShape = sensor.GetObservationShape();
            var dim1Bp      = sensorShape[0];
            var dim2Bp      = sensorShape[1];
            var dim1T       = tensorProxy.Channels;
            var dim2T       = tensorProxy.Width;

            var shapesAgree = (dim1Bp == dim1T) && (dim2Bp == dim2T);
            if (shapesAgree)
            {
                return null;
            }

            return $"An Observation of the model does not match. " +
                   $"Received TensorProxy of shape [?x{dim1Bp}x{dim2Bp}] but " +
                   $"was expecting [?x{dim1T}x{dim2T}].";
        }
コード例 #14
0
ファイル: Events.cs プロジェクト: zereyak13/ml-agents
        /// <summary>
        /// Builds an EventObservationSpec from a sensor's name, compression type and
        /// observation shape (one dimension info per shape entry).
        /// </summary>
        /// <param name="sensor">The sensor to describe.</param>
        /// <returns>The populated spec.</returns>
        public static EventObservationSpec FromSensor(ISensor sensor)
        {
            var dims      = sensor.GetObservationShape();
            var dimInfos  = new EventObservationDimensionInfo[dims.Length];
            var dimension = 0;

            while (dimension < dims.Length)
            {
                dimInfos[dimension].Size = dims[dimension];
                // TODO copy flags when we have them
                dimension++;
            }

            var sensorName      = sensor.GetName();
            var compressionName = sensor.GetCompressionType().ToString();

            return new EventObservationSpec
            {
                SensorName = sensorName,
                CompressionType = compressionName,
                DimensionInfos = dimInfos,
            };
        }
コード例 #15
0
        /// <summary>
        /// Initializes the sensor. The reported shape is the wrapped shape with its last
        /// dimension multiplied by <paramref name="numStackedObservations"/>. Buffers for
        /// uncompressed observations are always allocated; buffers for compressed
        /// observations are allocated only when the wrapped sensor uses compression.
        /// </summary>
        /// <param name="wrapped">The wrapped sensor.</param>
        /// <param name="numStackedObservations">Number of stacked observations to keep. Must be at least 1.</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="wrapped"/> is null.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// <paramref name="numStackedObservations"/> is less than 1.
        /// </exception>
        public StackingSensor(ISensor wrapped, int numStackedObservations)
        {
            // Fail fast with a descriptive error instead of a later NullReferenceException
            // or a failed array allocation.
            if (wrapped == null)
            {
                throw new System.ArgumentNullException(nameof(wrapped));
            }
            if (numStackedObservations < 1)
            {
                throw new System.ArgumentOutOfRangeException(
                    nameof(numStackedObservations),
                    numStackedObservations,
                    "StackingSensor needs at least one stacked observation.");
            }

            m_WrappedSensor          = wrapped;
            m_NumStackedObservations = numStackedObservations;

            m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";

            m_WrappedShape = wrapped.GetObservationShape();

            m_UnstackedObservationSize = wrapped.ObservationSize();

            // Keep the stacked shape in its own array so m_WrappedShape stays unchanged.
            m_Shape = (int[])m_WrappedShape.Clone();

            // TODO support arbitrary stacking dimension
            m_Shape[m_Shape.Length - 1] *= numStackedObservations;

            // Initialize uncompressed buffer anyway in case python trainer does not
            // support the compression mapping and has to fall back to uncompressed obs.
            m_StackedObservations = new float[numStackedObservations][];
            for (var i = 0; i < numStackedObservations; i++)
            {
                m_StackedObservations[i] = new float[m_UnstackedObservationSize];
            }

            if (m_WrappedSensor.GetCompressionType() != SensorCompressionType.None)
            {
                // Every slot starts out referring to the same blank image.
                m_StackedCompressedObservations = new byte[numStackedObservations][];
                m_EmptyCompressedObservation    = CreateEmptyPNG();
                for (var i = 0; i < numStackedObservations; i++)
                {
                    m_StackedCompressedObservations[i] = m_EmptyCompressedObservation;
                }
                m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
            }

            // NOTE(review): any non-1-D shape is indexed up to [2] here, so a rank-2 shape
            // would throw IndexOutOfRangeException - confirm only rank 1 and rank 3 are expected.
            if (m_Shape.Length != 1)
            {
                m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]);
            }
        }
コード例 #16
0
        /// <summary>
        /// Compares the first three entries of the sensor's observation shape (read as
        /// height, width, channels) with the Height, Width and Channels of the visual
        /// observation input placeholder.
        /// </summary>
        /// <param name="tensorProxy">The tensor that is expected by the model</param>
        /// <param name="sensor">The sensor that produces the visual observation.</param>
        /// <returns>
        /// Null when all three dimensions agree; otherwise a string describing the mismatch.
        /// </returns>
        static string CheckVisualObsShape(
            TensorProxy tensorProxy, ISensor sensor)
        {
            var sensorShape = sensor.GetObservationShape();
            var heightBp    = sensorShape[0];
            var widthBp     = sensorShape[1];
            var pixelBp     = sensorShape[2];
            var heightT     = tensorProxy.Height;
            var widthT      = tensorProxy.Width;
            var pixelT      = tensorProxy.Channels;

            var shapesAgree = (widthBp == widthT) && (heightBp == heightT) && (pixelBp == pixelT);
            if (shapesAgree)
            {
                return null;
            }

            return $"The visual Observation of the model does not match. " +
                   $"Received TensorProxy of shape [?x{widthBp}x{heightBp}x{pixelBp}] but " +
                   $"was expecting [?x{widthT}x{heightT}x{pixelT}].";
        }
コード例 #17
0
        /// <summary>
        /// Writes the wrapped sensor's observation into the buffer at the current index,
        /// then copies every stored buffer into <paramref name="adapter"/>, starting with
        /// the slot after the current index and wrapping around.
        /// </summary>
        /// <param name="adapter">Destination for the stacked observations.</param>
        /// <returns>The number of elements written: one unstacked observation per stored buffer.</returns>
        public int Write(WriteAdapter adapter)
        {
            // The wrapped sensor writes through our private adapter, never through the caller's.
            var innerShape = m_WrappedSensor.GetObservationShape();
            m_LocalAdapter.SetTarget(m_StackedObservations[m_CurrentIndex], innerShape, 0);
            m_WrappedSensor.Write(m_LocalAdapter);

            // Emit the saved observations, oldest first; the newest (current slot) comes last.
            var writeOffset = 0;
            for (var step = 1; step <= m_NumStackedObservations; step++)
            {
                var slot = (m_CurrentIndex + step) % m_NumStackedObservations;
                adapter.AddRange(m_StackedObservations[slot], writeOffset);
                writeOffset += m_UnstackedObservationSize;
            }

            return writeOffset;
        }
コード例 #18
0
ファイル: Events.cs プロジェクト: PedroLelis/ml-agents
        /// <summary>
        /// Builds an EventObservationSpec from a sensor's name, compression type, built-in
        /// sensor type and observation shape (one dimension info per shape entry).
        /// Sensors that do not implement IBuiltInSensor are reported as Unknown.
        /// </summary>
        /// <param name="sensor">The sensor to describe.</param>
        /// <returns>The populated spec.</returns>
        public static EventObservationSpec FromSensor(ISensor sensor)
        {
            var dims      = sensor.GetObservationShape();
            var dimInfos  = new EventObservationDimensionInfo[dims.Length];
            var dimension = 0;

            while (dimension < dims.Length)
            {
                dimInfos[dimension].Size = dims[dimension];
                // TODO copy flags when we have them
                dimension++;
            }

            var builtInSensor     = sensor as IBuiltInSensor;
            var builtInSensorType = builtInSensor != null
                ? builtInSensor.GetBuiltInSensorType()
                : Sensors.BuiltInSensorType.Unknown;

            var sensorName      = sensor.GetName();
            var compressionName = sensor.GetCompressionType().ToString();

            return new EventObservationSpec
            {
                SensorName = sensorName,
                CompressionType = compressionName,
                BuiltInSensorType = (int)builtInSensorType,
                DimensionInfos = dimInfos,
            };
        }
コード例 #19
0
ファイル: Events.cs プロジェクト: ishitavohra3110/ml-agents
        /// <summary>
        /// Builds an EventObservationSpec from a sensor's name, compression type, built-in
        /// sensor type, observation shape and, when available, per-dimension properties.
        /// Sensors that do not implement IBuiltInSensor are reported as Unknown.
        /// </summary>
        /// <param name="sensor">The sensor to describe.</param>
        /// <returns>The populated spec.</returns>
        public static EventObservationSpec FromSensor(ISensor sensor)
        {
            var shape    = sensor.GetObservationShape();
            var dimProps = (sensor as IDimensionPropertiesSensor)?.GetDimensionProperties();
            var dimInfos = new EventObservationDimensionInfo[shape.Length];

            for (var i = 0; i < shape.Length; i++)
            {
                dimInfos[i].Size = shape[i];

                // Flags default to 0 when the sensor exposes no dimension properties, or
                // fewer properties than shape entries, instead of indexing out of range.
                var hasProperty = dimProps != null && i < dimProps.Length;
                dimInfos[i].Flags = hasProperty ? (int)dimProps[i] : 0;
            }

            var builtInSensorType =
                (sensor as IBuiltInSensor)?.GetBuiltInSensorType() ?? Sensors.BuiltInSensorType.Unknown;

            return new EventObservationSpec
            {
                SensorName = sensor.GetName(),
                CompressionType = sensor.GetCompressionType().ToString(),
                BuiltInSensorType = (int)builtInSensorType,
                DimensionInfos = dimInfos,
            };
        }
コード例 #20
0
 /// <summary>
 /// Returns the observation shape reported by <c>Sensor</c>.
 /// </summary>
 /// <returns>The array returned by <c>Sensor.GetObservationShape()</c>.</returns>
 public override int[] GetObservationShape()
 {
     var delegatedShape = Sensor.GetObservationShape();
     return delegatedShape;
 }
コード例 #21
0
        /// <summary>
        /// Generate an ObservationProto for the sensor using the provided ObservationWriter.
        /// This is equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// Compressed observations with more than 3 channels fall back to uncompressed
        /// float data when the attached trainer lacks the required capability.
        /// </summary>
        /// <param name="sensor">The sensor to read the observation from.</param>
        /// <param name="observationWriter">Writer used to write uncompressed float data into the proto.</param>
        /// <returns>
        /// A proto holding float or compressed data, the shape, any dimension properties
        /// and the observation type.
        /// </returns>
        /// <exception cref="UnityAgentsException">
        /// Thrown when a compressed sensor returns null from GetCompressedObservation(), or
        /// when the sensor produces variable length observations the trainer cannot handle.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
        {
            var shape = sensor.GetObservationShape();
            ObservationProto observationProto = null;
            var compressionType = sensor.GetCompressionType();

            // Check capabilities if we need to concatenate PNGs
            if (compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3)
            {
                var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
                if (!trainerCanHandle)
                {
                    // Warn only once per session.
                    if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
                    {
                        Debug.LogWarning(
                            $"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                            );
                        s_HaveWarnedTrainerCapabilitiesMultiPng = true;
                    }
                    compressionType = SensorCompressionType.None;
                }
            }
            // Check capabilities if we need mapping for compressed observations
            if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
            {
                var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
                var isTrivialMapping        = IsTrivialMapping(sensor);
                if (!trainerCanHandleMapping && !isTrivialMapping)
                {
                    // Warn only once per session.
                    if (!s_HaveWarnedTrainerCapabilitiesMapping)
                    {
                        Debug.LogWarning(
                            $"The sensor {sensor.GetName()} is using non-trivial mapping and " +
                            "the attached trainer doesn't support compression mapping. " +
                            "Switching to uncompressed observations. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                            );
                        s_HaveWarnedTrainerCapabilitiesMapping = true;
                    }
                    compressionType = SensorCompressionType.None;
                }
            }

            if (compressionType == SensorCompressionType.None)
            {
                var numFloats      = sensor.ObservationSize();
                var floatDataProto = new ObservationProto.Types.FloatData();
                // Resize the float array
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                for (var i = 0; i < numFloats; i++)
                {
                    floatDataProto.Data.Add(0.0f);
                }

                // Reuse the shape fetched above instead of querying the sensor again.
                observationWriter.SetTarget(floatDataProto.Data, shape, 0);
                sensor.Write(observationWriter);

                observationProto = new ObservationProto
                {
                    FloatData       = floatDataProto,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }
            else
            {
                var compressedObs = sensor.GetCompressedObservation();
                if (compressedObs == null)
                {
                    throw new UnityAgentsException(
                              $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                              "You must return a byte[]. If you don't want to use compressed observations, " +
                              "return SensorCompressionType.None from GetCompressionType()."
                              );
                }
                // In this branch compressionType still holds the value the sensor reported:
                // the capability checks above only ever replace it with None.
                observationProto = new ObservationProto
                {
                    CompressedData  = ByteString.CopyFrom(compressedObs),
                    CompressionType = (CompressionTypeProto)compressionType,
                };
                var compressibleSensor = sensor as ISparseChannelSensor;
                if (compressibleSensor != null)
                {
                    // Skip a missing mapping rather than handing null to AddRange.
                    var channelMapping = compressibleSensor.GetCompressedChannelMapping();
                    if (channelMapping != null)
                    {
                        observationProto.CompressedChannelMapping.AddRange(channelMapping);
                    }
                }
            }
            // Add the dimension properties if any to the observationProto
            var dimensionPropertySensor = sensor as IDimensionPropertiesSensor;

            if (dimensionPropertySensor != null)
            {
                var dimensionProperties = dimensionPropertySensor.GetDimensionProperties();
                for (int i = 0; i < dimensionProperties.Length; i++)
                {
                    observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
                }
                // Checking trainer compatibility with variable length observations
                if (dimensionProperties.Length == 2)
                {
                    if (dimensionProperties[0] == DimensionProperty.VariableSize &&
                        dimensionProperties[1] == DimensionProperty.None)
                    {
                        var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
                        if (!trainerCanHandleVarLenObs)
                        {
                            throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
                        }
                    }
                }
            }
            observationProto.Shape.AddRange(shape);

            // Add the observation type, if any, to the observationProto
            var typeSensor = sensor as ITypedSensor;

            if (typeSensor != null)
            {
                observationProto.ObservationType = (ObservationTypeProto)typeSensor.GetObservationType();
            }
            else
            {
                observationProto.ObservationType = ObservationTypeProto.Default;
            }
            return observationProto;
        }
コード例 #22
0
        /// <summary>
        /// Builds an <see cref="ObservationProto"/> directly from a sensor, writing through the
        /// supplied <see cref="ObservationWriter"/> when the observation is sent uncompressed.
        /// The result matches what creating an Observation and calling Observation.ToProto()
        /// would give, while skipping some intermediate allocations.
        /// </summary>
        /// <param name="sensor">Sensor whose shape, compression type and data are serialized.</param>
        /// <param name="observationWriter">
        /// Writer that is pointed at the proto's float data and handed to the sensor's
        /// <c>Write</c> call. Only used on the uncompressed path.
        /// </param>
        /// <returns>
        /// A proto holding either float data or compressed bytes, with the sensor's shape appended.
        /// </returns>
        /// <exception cref="UnityAgentsException">
        /// Thrown when the compressed path is taken and the sensor returns null from
        /// <c>GetCompressedObservation()</c>.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
        {
            var shape = sensor.GetObservationShape();
            var compressionType = sensor.GetCompressionType();
            ObservationProto proto;

            // A PNG observation with three dimensions and more than 3 entries in the last one
            // needs the trainer to accept concatenated PNGs.
            var needsConcatenatedPngs =
                compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3;
            if (needsConcatenatedPngs)
            {
                // Unknown capabilities (null) are treated as "supported".
                var trainerLacksSupport = Academy.Instance.TrainerCapabilities != null &&
                    !Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
                if (trainerLacksSupport)
                {
                    // The static flag keeps the warning from being logged more than once.
                    if (!s_HaveWarnedAboutTrainerCapabilities)
                    {
                        Debug.LogWarning($"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}.");
                        s_HaveWarnedAboutTrainerCapabilities = true;
                    }
                    // Fall back to sending raw floats.
                    compressionType = SensorCompressionType.None;
                }
            }

            if (compressionType != SensorCompressionType.None)
            {
                // Compressed path: the sensor supplies the encoded bytes itself.
                var payload = sensor.GetCompressedObservation();
                if (payload == null)
                {
                    throw new UnityAgentsException(
                              $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                              "You must return a byte[]. If you don't want to use compressed observations, " +
                              "return SensorCompressionType.None from GetCompressionType()."
                              );
                }

                proto = new ObservationProto();
                proto.CompressedData = ByteString.CopyFrom(payload);
                proto.CompressionType = (CompressionTypeProto)sensor.GetCompressionType();
            }
            else
            {
                // Uncompressed path: have the sensor write its floats straight into the proto's storage.
                var floatCount = sensor.ObservationSize();
                var floats = new ObservationProto.Types.FloatData();

                // Pre-fill with zeros so the repeated field has floatCount entries.
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                for (var remaining = floatCount; remaining > 0; remaining--)
                {
                    floats.Data.Add(0.0f);
                }

                observationWriter.SetTarget(floats.Data, sensor.GetObservationShape(), 0);
                sensor.Write(observationWriter);

                proto = new ObservationProto();
                proto.FloatData = floats;
                proto.CompressionType = (CompressionTypeProto)SensorCompressionType.None;
            }

            proto.Shape.AddRange(shape);
            return proto;
        }