Ejemplo n.º 1
0
        /// <summary>
        /// Moves the currently selected sensor to the x/y coordinates supplied as the two
        /// command parameters, as long as that sensor implements <c>IMovable</c>.
        /// </summary>
        /// <param name="parameters">
        /// Command arguments; exactly two are expected (x then y), parsed with the current culture.
        /// </param>
        private static void Move(IEnumerable <string> parameters)
        {
            var mover = CurrentSensor as IMovable;
            if (mover == null)
            {
                Console.WriteLine("Please select a IMovable sensor to use the move function");
                return;
            }

            if (parameters.Count() != 2)
            {
                Console.WriteLine($"Command move requires two parameters.\n\tSyntax: move {9.99F} {9.99F}");
                return;
            }

            try
            {
                var targetX = float.Parse(parameters.First());
                var targetY = float.Parse(parameters.Skip(1).First());
                mover.Move(targetX, targetY);
                Console.WriteLine($"Sensor {CurrentSensor.GetName()} moved to x={targetX} y={targetY}");
            }
            catch (FormatException)
            {
                // Either coordinate failed to parse under the current culture's number format.
                Console.WriteLine($"Wrong floating-point format according to current culture. Try move {9.99F} {9.99F}");
            }
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Checks that the shape of the rank 1 observation input placeholder is the same as the corresponding sensor.
        /// </summary>
        /// <param name="tensorProxy">The tensor that is expected by the model.</param>
        /// <param name="sensor">The sensor that produces the rank 1 observation.</param>
        /// <returns>
        /// If the check failed, returns a <see cref="FailedCheck"/> warning describing why the
        /// check failed. If the check passed, returns null.
        /// </returns>
        static FailedCheck CheckRankOneObsShape(
            TensorProxy tensorProxy, ISensor sensor)
        {
            var shape  = sensor.GetObservationSpec().Shape;
            var dim1Bp = shape[0];
            var dim1T  = tensorProxy.Channels;
            var dim2T  = tensorProxy.Width;
            var dim3T  = tensorProxy.Height;

            // Only the first (and only) sensor dimension is compared against the tensor's channels.
            if (dim1Bp == dim1T)
            {
                return null;
            }

            // Describe the model's expected shape with as many dimensions as it really uses:
            // Height and Width are only shown when they are larger than 1.
            string proxyDimStr;
            if (dim3T > 1)
            {
                proxyDimStr = $"[?x{dim3T}x{dim2T}x{dim1T}]";
            }
            else if (dim2T > 1)
            {
                proxyDimStr = $"[?x{dim1T}x{dim2T}]";
            }
            else
            {
                proxyDimStr = $"[?x{dim1T}]";
            }

            return FailedCheck.Warning($"An Observation of the model does not match. " +
                                       $"Received TensorProxy of shape [?x{dim1Bp}] but " +
                                       $"was expecting {proxyDimStr} for the {sensor.GetName()} Sensor."
                                       );
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Creates a sensor that keeps the last <paramref name="numStackedObservations"/>
        /// observations of <paramref name="wrapped"/> and exposes them stacked along the
        /// first dimension.
        /// </summary>
        /// <param name="wrapped">The wrapped sensor</param>
        /// <param name="numStackedObservations">Number of stacked observations to keep</param>
        public StackingSensor(ISensor wrapped, int numStackedObservations)
        {
            // TODO ensure numStackedObservations > 1
            m_WrappedSensor          = wrapped;
            m_NumStackedObservations = numStackedObservations;

            m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";

            // Take a private copy of the wrapped shape so that scaling the stacked
            // dimension below leaves the wrapped sensor's array untouched.
            var wrappedShape = wrapped.GetObservationShape();
            m_Shape = (int[])wrappedShape.Clone();

            m_UnstackedObservationSize = wrapped.ObservationSize();

            // TODO support arbitrary stacking dimension
            m_Shape[0] *= numStackedObservations;

            // One buffer per stacked step, each the size of a single wrapped observation.
            m_StackedObservations = new float[numStackedObservations][];
            for (var slot = 0; slot < m_StackedObservations.Length; slot++)
            {
                m_StackedObservations[slot] = new float[m_UnstackedObservationSize];
            }
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Builds an <see cref="EventObservationSpec"/> describing <paramref name="sensor"/>:
        /// its name, compression type, built-in sensor type, observation type, and the size
        /// and dimension-property flags of every dimension.
        /// </summary>
        /// <param name="sensor">The sensor to describe.</param>
        /// <returns>A new spec populated from the sensor.</returns>
        public static EventObservationSpec FromSensor(ISensor sensor)
        {
            var spec           = sensor.GetObservationSpec();
            var sizes          = spec.Shape;
            var properties     = spec.DimensionProperties;
            var dimensionInfos = new EventObservationDimensionInfo[sizes.Length];

            var dim = 0;
            while (dim < sizes.Length)
            {
                dimensionInfos[dim].Size  = sizes[dim];
                dimensionInfos[dim].Flags = (int)properties[dim];
                dim++;
            }

            // Sensors that do not implement IBuiltInSensor are reported as Unknown.
            var builtInSensor     = sensor as IBuiltInSensor;
            var builtInSensorType = builtInSensor != null
                ? builtInSensor.GetBuiltInSensorType()
                : Sensors.BuiltInSensorType.Unknown;

            var result = new EventObservationSpec
            {
                SensorName        = sensor.GetName(),
                CompressionType   = sensor.GetCompressionSpec().SensorCompressionType.ToString(),
                BuiltInSensorType = (int)builtInSensorType,
                ObservationType   = (int)spec.ObservationType,
                DimensionInfos    = dimensionInfos,
            };
            return result;
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Initializes the sensor.
        /// </summary>
        /// <remarks>
        /// Only uncompressed, one-dimensional wrapped sensors are accepted. The single dimension
        /// is multiplied by <paramref name="numStackedObservations"/>.
        /// </remarks>
        /// <param name="wrapped">The wrapped sensor.</param>
        /// <param name="numStackedObservations">Number of stacked observations to keep.</param>
        /// <exception cref="UnityAgentsException">The wrapped sensor is compressed or is not 1-D.</exception>
        public StackingSensor(ISensor wrapped, int numStackedObservations)
        {
            // TODO ensure numStackedObservations > 1
            m_WrappedSensor          = wrapped;
            m_NumStackedObservations = numStackedObservations;

            m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";

            var isCompressed = wrapped.GetCompressionType() != SensorCompressionType.None;
            if (isCompressed)
            {
                throw new UnityAgentsException("StackingSensor doesn't support compressed observations.'");
            }

            var wrappedShape = wrapped.GetObservationShape();
            if (wrappedShape.Length != 1)
            {
                throw new UnityAgentsException("Only 1-D observations are supported by StackingSensor");
            }

            // Work on a copy so the scaling below does not modify the wrapped sensor's array.
            m_Shape = (int[])wrappedShape.Clone();

            m_UnstackedObservationSize = wrapped.ObservationSize();

            // TODO support arbitrary stacking dimension
            m_Shape[0] *= numStackedObservations;

            // One buffer per stacked step, each the size of a single wrapped observation.
            m_StackedObservations = new float[numStackedObservations][];
            for (var slot = 0; slot < m_StackedObservations.Length; slot++)
            {
                m_StackedObservations[slot] = new float[m_UnstackedObservationSize];
            }
        }
Ejemplo n.º 6
0
 /// <summary>
 ///     Records the quality calculated by <paramref name="sensor"/> in <paramref name="result"/>,
 ///     keyed by the sensor's name. A null sensor is ignored.
 /// </summary>
 /// <param name="sensor">Sensor instance; may be null, in which case nothing is added.</param>
 /// <param name="result">
 ///     Map from sensor name to calculated quality. Uses <c>IDictionary.Add</c>, so a name
 ///     that is already present causes the dictionary to throw.
 /// </param>
 private static void AddToResult(ISensor sensor, IDictionary <string, string> result)
 {
     if (sensor == null)
     {
         return;
     }

     var sensorName = sensor.GetName();
     var quality    = sensor.CalculateQuality();
     result.Add(sensorName, quality);
 }
Ejemplo n.º 7
0
        /// <summary>
        /// Initializes the sensor.
        /// </summary>
        /// <remarks>
        /// Caches the wrapped sensor's observation spec and derives this sensor's spec by
        /// multiplying the last dimension of the wrapped shape by <paramref name="numStackedObservations"/>.
        /// Allocates one uncompressed buffer per stacked step. If the wrapped sensor uses
        /// compression, it also sets up per-step compressed buffers and a stacked channel mapping.
        /// </remarks>
        /// <param name="wrapped">The wrapped sensor.</param>
        /// <param name="numStackedObservations">Number of stacked observations to keep.</param>
        public StackingSensor(ISensor wrapped, int numStackedObservations)
        {
            // TODO ensure numStackedObservations > 1
            m_WrappedSensor          = wrapped;
            m_NumStackedObservations = numStackedObservations;

            m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";

            m_WrappedSpec = wrapped.GetObservationSpec();

            // Size of one (unstacked) observation from the wrapped sensor; each stacked buffer uses it.
            m_UnstackedObservationSize = wrapped.ObservationSize();

            // Set up the cached observation spec for the StackingSensor
            // NOTE(review): newShape is modified in place below. m_WrappedSpec stays unchanged only if
            // Shape is a value type that is copied on assignment. Confirm against the ObservationSpec definition.
            var newShape = m_WrappedSpec.Shape;

            // TODO support arbitrary stacking dimension
            // Observations are stacked along the last dimension.
            newShape[newShape.Length - 1] *= numStackedObservations;
            m_ObservationSpec              = new ObservationSpec(
                newShape, m_WrappedSpec.DimensionProperties, m_WrappedSpec.ObservationType
                );

            // Initialize uncompressed buffer anyway in case python trainer does not
            // support the compression mapping and has to fall back to uncompressed obs.
            m_StackedObservations = new float[numStackedObservations][];
            for (var i = 0; i < numStackedObservations; i++)
            {
                m_StackedObservations[i] = new float[m_UnstackedObservationSize];
            }

            if (m_WrappedSensor.GetCompressionSpec().SensorCompressionType != SensorCompressionType.None)
            {
                // Every compressed slot starts out referencing the same placeholder returned by CreateEmptyPNG().
                m_StackedCompressedObservations = new byte[numStackedObservations][];
                m_EmptyCompressedObservation    = CreateEmptyPNG();
                for (var i = 0; i < numStackedObservations; i++)
                {
                    m_StackedCompressedObservations[i] = m_EmptyCompressedObservation;
                }
                m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
            }

            // For non-rank-1 observations, cache the wrapped shape as a TensorShape
            // (first argument 0, presumably the batch dimension — confirm).
            // NOTE(review): this reads three dimensions, so it appears to assume any non-rank-1 spec is
            // rank 3. A rank-2 wrapped spec would presumably fail on wrappedShape[2]. Confirm.
            if (m_WrappedSpec.Rank != 1)
            {
                var wrappedShape = m_WrappedSpec.Shape;
                m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]);
            }
        }
Ejemplo n.º 8
0
        /// <summary>
        /// Builds an <see cref="EventObservationSpec"/> from a sensor's name, compression
        /// type, and the size of each observation dimension.
        /// </summary>
        /// <param name="sensor">The sensor to describe.</param>
        /// <returns>A new spec populated from the sensor.</returns>
        public static EventObservationSpec FromSensor(ISensor sensor)
        {
            var observationShape = sensor.GetObservationShape();
            var dimensionInfos   = new EventObservationDimensionInfo[observationShape.Length];

            var dim = 0;
            while (dim < observationShape.Length)
            {
                // TODO copy flags when we have them
                dimensionInfos[dim].Size = observationShape[dim];
                dim++;
            }

            var spec = new EventObservationSpec
            {
                SensorName      = sensor.GetName(),
                CompressionType = sensor.GetCompressionType().ToString(),
                DimensionInfos  = dimensionInfos,
            };
            return spec;
        }
Ejemplo n.º 9
0
        /// <summary>
        /// Initializes the sensor.
        /// </summary>
        /// <remarks>
        /// The stacked shape is the wrapped shape with its last dimension multiplied by
        /// <paramref name="numStackedObservations"/>. Uncompressed buffers are always allocated.
        /// Compressed buffers and a channel mapping are added when the wrapped sensor compresses.
        /// </remarks>
        /// <param name="wrapped">The wrapped sensor.</param>
        /// <param name="numStackedObservations">Number of stacked observations to keep.</param>
        public StackingSensor(ISensor wrapped, int numStackedObservations)
        {
            // TODO ensure numStackedObservations > 1
            m_WrappedSensor          = wrapped;
            m_NumStackedObservations = numStackedObservations;

            m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";

            m_WrappedShape = wrapped.GetObservationShape();
            // Independent copy: scaling the stacked dimension must not alter m_WrappedShape.
            m_Shape        = (int[])m_WrappedShape.Clone();

            m_UnstackedObservationSize = wrapped.ObservationSize();

            // TODO support arbitrary stacking dimension
            var stackedDim = m_Shape.Length - 1;
            m_Shape[stackedDim] *= numStackedObservations;

            // Initialize uncompressed buffer anyway in case python trainer does not
            // support the compression mapping and has to fall back to uncompressed obs.
            m_StackedObservations = new float[numStackedObservations][];
            for (var slot = 0; slot < m_StackedObservations.Length; slot++)
            {
                m_StackedObservations[slot] = new float[m_UnstackedObservationSize];
            }

            var wrappedIsCompressed = m_WrappedSensor.GetCompressionType() != SensorCompressionType.None;
            if (wrappedIsCompressed)
            {
                m_StackedCompressedObservations = new byte[numStackedObservations][];
                m_EmptyCompressedObservation    = CreateEmptyPNG();
                // Every slot initially references the same empty placeholder.
                for (var slot = 0; slot < m_StackedCompressedObservations.Length; slot++)
                {
                    m_StackedCompressedObservations[slot] = m_EmptyCompressedObservation;
                }
                m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
            }

            var isVectorObservation = m_Shape.Length == 1;
            if (!isVectorObservation)
            {
                m_tensorShape = new TensorShape(0, m_WrappedShape[0], m_WrappedShape[1], m_WrappedShape[2]);
            }
        }
Ejemplo n.º 10
0
        /// <summary>
        /// Raises or lowers the currently selected sensor, provided it implements
        /// <c>IHeightAdjustable</c>.
        /// </summary>
        /// <param name="parameters">
        /// Exactly two arguments: the direction ("raise" or "lower") followed by the amount,
        /// which is parsed as a float using the current culture.
        /// </param>
        private static void AdjustHeight(IEnumerable <string> parameters)
        {
            var adjustable = CurrentSensor as IHeightAdjustable;
            if (adjustable == null)
            {
                Console.WriteLine($"Current sensor {CurrentSensor?.GetName() ?? "null"} is not height adjustable");
                return;
            }

            if (parameters.Count() != 2)
            {
                Console.WriteLine($"Parameter mismatch.\n\tSyntax: adjustheight [raise|lower] {9.99F}");
                return;
            }

            var direction = parameters.First();
            var height    = parameters.Skip(1).First();
            try
            {
                if (direction == "raise")
                {
                    adjustable.Raise(float.Parse(height));
                    Console.WriteLine($"{CurrentSensor.GetName()} raised.");
                }
                else if (direction == "lower")
                {
                    adjustable.Lower(float.Parse(height));
                    Console.WriteLine($"{CurrentSensor.GetName()} lowered.");
                }
                else
                {
                    // An unknown direction is reported before the amount is ever parsed.
                    Console.WriteLine("First parameter to height adjustment must be 'raise' or 'lower'");
                }
            }
            catch (FormatException)
            {
                Console.WriteLine($"Second parameter {height} is not a valid float according to current culture. Try {9.99F}");
            }
        }
Ejemplo n.º 11
0
        /// <summary>
        /// Builds an <see cref="EventObservationSpec"/> describing <paramref name="sensor"/>:
        /// its name, compression type, built-in sensor type, and the size of each dimension.
        /// </summary>
        /// <param name="sensor">The sensor to describe.</param>
        /// <returns>A new spec populated from the sensor.</returns>
        public static EventObservationSpec FromSensor(ISensor sensor)
        {
            var observationShape = sensor.GetObservationShape();
            var dimensionInfos   = new EventObservationDimensionInfo[observationShape.Length];

            var dim = 0;
            while (dim < observationShape.Length)
            {
                // TODO copy flags when we have them
                dimensionInfos[dim].Size = observationShape[dim];
                dim++;
            }

            // Sensors that do not implement IBuiltInSensor are reported as Unknown.
            var builtInSensor     = sensor as IBuiltInSensor;
            var builtInSensorType = builtInSensor != null
                ? builtInSensor.GetBuiltInSensorType()
                : Sensors.BuiltInSensorType.Unknown;

            var spec = new EventObservationSpec
            {
                SensorName        = sensor.GetName(),
                CompressionType   = sensor.GetCompressionType().ToString(),
                BuiltInSensorType = (int)builtInSensorType,
                DimensionInfos    = dimensionInfos,
            };
            return spec;
        }
Ejemplo n.º 12
0
        /// <summary>
        /// Checks that the shape of the visual observation input placeholder is the same as the corresponding sensor.
        /// </summary>
        /// <param name="tensorProxy">The tensor that is expected by the model.</param>
        /// <param name="sensor">The sensor that produces the visual observation.</param>
        /// <returns>
        /// A <see cref="FailedCheck"/> warning when the sensor's height, width, or channel count
        /// differs from the tensor's; null when all three match.
        /// </returns>
        static FailedCheck CheckVisualObsShape(
            TensorProxy tensorProxy, ISensor sensor)
        {
            // Sensor shape is read as [height, width, channels].
            var sensorShape = sensor.GetObservationSpec().Shape;
            var heightBp    = sensorShape[0];
            var widthBp     = sensorShape[1];
            var pixelBp     = sensorShape[2];
            var heightT     = tensorProxy.Height;
            var widthT      = tensorProxy.Width;
            var pixelT      = tensorProxy.Channels;

            var shapesMatch = widthBp == widthT && heightBp == heightT && pixelBp == pixelT;
            if (shapesMatch)
            {
                return null;
            }

            return FailedCheck.Warning($"The visual Observation of the model does not match. " +
                                       $"Received TensorProxy of shape [?x{widthBp}x{heightBp}x{pixelBp}] but " +
                                       $"was expecting [?x{widthT}x{heightT}x{pixelT}] for the {sensor.GetName()} Sensor."
                                       );
        }
Ejemplo n.º 13
0
        /// <summary>
        /// Generate an ObservationProto for the sensor using the provided ObservationWriter.
        /// This is equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// </summary>
        /// <remarks>
        /// The sensor's requested compression is downgraded to uncompressed floats in two cases,
        /// both limited to rank-3 shapes with more than 3 channels:
        /// when the trainer cannot handle concatenated PNGs (PNG compression only), and
        /// when it cannot handle a non-trivial compressed channel mapping (any compression).
        /// Each downgrade logs a warning only while its s_HaveWarned* flag is unset, then sets the flag.
        /// A null <c>Academy.Instance.TrainerCapabilities</c> is treated as supporting every feature.
        /// </remarks>
        /// <param name="sensor">The sensor whose observation is converted.</param>
        /// <param name="observationWriter">Writer used to fill the float data on the uncompressed path.</param>
        /// <returns>
        /// The populated proto: data, compression type, channel mapping (if provided), dimension
        /// properties, shape, name (when non-empty), and observation type.
        /// </returns>
        /// <exception cref="UnityAgentsException">
        /// Thrown when the compressed path is taken but GetCompressedObservation() returns null.
        /// Also thrown when the dimension properties are exactly [VariableSize, None] and the
        /// trainer does not support variable length observations.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
        {
            var obsSpec = sensor.GetObservationSpec();
            var shape   = obsSpec.Shape;
            ObservationProto observationProto = null;
            var compressionSpec = sensor.GetCompressionSpec();
            // Local copy of the requested compression; the capability checks below may downgrade it to None.
            var compressionType = compressionSpec.SensorCompressionType;

            // Check capabilities if we need to concatenate PNGs
            if (compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3)
            {
                var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
                if (!trainerCanHandle)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
                    {
                        Debug.LogWarning(
                            $"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                            );
                        s_HaveWarnedTrainerCapabilitiesMultiPng = true;
                    }
                    compressionType = SensorCompressionType.None;
                }
            }
            // Check capabilities if we need mapping for compressed observations
            if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
            {
                var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
                var isTrivialMapping        = compressionSpec.IsTrivialMapping();
                if (!trainerCanHandleMapping && !isTrivialMapping)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMapping)
                    {
                        Debug.LogWarning(
                            $"The sensor {sensor.GetName()} is using non-trivial mapping and " +
                            "the attached trainer doesn't support compression mapping. " +
                            "Switching to uncompressed observations. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                            );
                        s_HaveWarnedTrainerCapabilitiesMapping = true;
                    }
                    compressionType = SensorCompressionType.None;
                }
            }

            // Uncompressed path: pre-fill the float list with ObservationSize() zeros, then let the sensor write into it.
            if (compressionType == SensorCompressionType.None)
            {
                var numFloats      = sensor.ObservationSize();
                var floatDataProto = new ObservationProto.Types.FloatData();
                // Resize the float array
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                for (var i = 0; i < numFloats; i++)
                {
                    floatDataProto.Data.Add(0.0f);
                }

                observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationSpec(), 0);
                sensor.Write(observationWriter);

                observationProto = new ObservationProto
                {
                    FloatData       = floatDataProto,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }
            // Compressed path: the sensor supplies the encoded bytes, which are copied into the proto.
            else
            {
                var compressedObs = sensor.GetCompressedObservation();
                if (compressedObs == null)
                {
                    throw new UnityAgentsException(
                              $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                              "You must return a byte[]. If you don't want to use compressed observations, " +
                              "return CompressionSpec.Default() from GetCompressionSpec()."
                              );
                }
                observationProto = new ObservationProto
                {
                    CompressedData  = ByteString.CopyFrom(compressedObs),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionSpec().SensorCompressionType,
                };
                if (compressionSpec.CompressedChannelMapping != null)
                {
                    observationProto.CompressedChannelMapping.AddRange(compressionSpec.CompressedChannelMapping);
                }
            }

            // Add the dimension properties to the observationProto
            var dimensionProperties = obsSpec.DimensionProperties;

            for (int i = 0; i < dimensionProperties.Length; i++)
            {
                observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
            }

            // Checking trainer compatibility with variable length observations
            // Only the exact [VariableSize, None] property pattern triggers this check.
            if (dimensionProperties == new InplaceArray <DimensionProperty>(DimensionProperty.VariableSize, DimensionProperty.None))
            {
                var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
                if (!trainerCanHandleVarLenObs)
                {
                    throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
                }
            }

            for (var i = 0; i < shape.Length; i++)
            {
                observationProto.Shape.Add(shape[i]);
            }

            var sensorName = sensor.GetName();

            // Name is only set when non-empty.
            if (!string.IsNullOrEmpty(sensorName))
            {
                observationProto.Name = sensorName;
            }

            observationProto.ObservationType = (ObservationTypeProto)obsSpec.ObservationType;
            return(observationProto);
        }
Ejemplo n.º 14
0
        /// <summary>
        /// Generate an ObservationProto for the sensor using the provided WriteAdapter.
        /// Equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// </summary>
        /// <param name="sensor">The sensor whose observation is converted.</param>
        /// <param name="writeAdapter">Adapter that fills the float buffer when the sensor is uncompressed.</param>
        /// <returns>The populated observation proto, with its shape copied from the sensor.</returns>
        /// <exception cref="UnityAgentsException">
        /// The sensor declares a compression type but GetCompressedObservation() returned null.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, WriteAdapter writeAdapter)
        {
            var shape = sensor.GetObservationShape();
            ObservationProto proto;

            if (sensor.GetCompressionType() != SensorCompressionType.None)
            {
                // Compressed: the sensor hands back already-encoded bytes.
                var compressedBytes = sensor.GetCompressedObservation();
                if (compressedBytes == null)
                {
                    throw new UnityAgentsException(
                              $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                              "You must return a byte[]. If you don't want to use compressed observations, " +
                              "return SensorCompressionType.None from GetCompressionType()."
                              );
                }

                proto = new ObservationProto
                {
                    CompressedData  = ByteString.CopyFrom(compressedBytes),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
                };
            }
            else
            {
                // Uncompressed: size the float list up front, then let the sensor write into it.
                var floatCount = sensor.ObservationSize();
                var floatData  = new ObservationProto.Types.FloatData();
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                for (var n = 0; n < floatCount; n++)
                {
                    floatData.Data.Add(0.0f);
                }

                writeAdapter.SetTarget(floatData.Data, sensor.GetObservationShape(), 0);
                sensor.Write(writeAdapter);

                proto = new ObservationProto
                {
                    FloatData       = floatData,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }

            proto.Shape.AddRange(shape);
            return proto;
        }
Ejemplo n.º 15
0
        /// <summary>
        /// Generate an ObservationProto for the sensor using the provided ObservationWriter.
        /// This is equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// </summary>
        /// <remarks>
        /// The sensor's requested compression is downgraded to uncompressed floats in two cases,
        /// both limited to rank-3 shapes with more than 3 channels:
        /// when the trainer cannot handle concatenated PNGs (PNG compression only), and
        /// when it cannot handle a non-trivial compressed channel mapping (any compression).
        /// Each downgrade logs a warning only while its s_HaveWarned* flag is unset, then sets the flag.
        /// A null <c>Academy.Instance.TrainerCapabilities</c> is treated as supporting every feature.
        /// </remarks>
        /// <param name="sensor">The sensor whose observation is converted.</param>
        /// <param name="observationWriter">Writer used to fill the float data on the uncompressed path.</param>
        /// <returns>The populated proto, with its shape copied from the sensor.</returns>
        /// <exception cref="UnityAgentsException">
        /// Thrown when the compressed path is taken but GetCompressedObservation() returns null.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
        {
            var shape = sensor.GetObservationShape();
            ObservationProto observationProto = null;
            // Local copy of the requested compression; the capability checks below may downgrade it to None.
            var compressionType = sensor.GetCompressionType();

            // Check capabilities if we need to concatenate PNGs
            if (compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3)
            {
                var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
                if (!trainerCanHandle)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
                    {
                        Debug.LogWarning(
                            $"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                            );
                        s_HaveWarnedTrainerCapabilitiesMultiPng = true;
                    }
                    compressionType = SensorCompressionType.None;
                }
            }
            // Check capabilities if we need mapping for compressed observations
            if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
            {
                var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
                var isTrivialMapping        = IsTrivialMapping(sensor);
                if (!trainerCanHandleMapping && !isTrivialMapping)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMapping)
                    {
                        Debug.LogWarning(
                            $"The sensor {sensor.GetName()} is using non-trivial mapping and " +
                            "the attached trainer doesn't support compression mapping. " +
                            "Switching to uncompressed observations. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                            );
                        s_HaveWarnedTrainerCapabilitiesMapping = true;
                    }
                    compressionType = SensorCompressionType.None;
                }
            }

            // Uncompressed path: pre-fill the float list with ObservationSize() zeros, then let the sensor write into it.
            if (compressionType == SensorCompressionType.None)
            {
                var numFloats      = sensor.ObservationSize();
                var floatDataProto = new ObservationProto.Types.FloatData();
                // Resize the float array
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                for (var i = 0; i < numFloats; i++)
                {
                    floatDataProto.Data.Add(0.0f);
                }

                observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
                sensor.Write(observationWriter);

                observationProto = new ObservationProto
                {
                    FloatData       = floatDataProto,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }
            // Compressed path: the sensor supplies the encoded bytes, which are copied into the proto.
            else
            {
                var compressedObs = sensor.GetCompressedObservation();
                if (compressedObs == null)
                {
                    throw new UnityAgentsException(
                              $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                              "You must return a byte[]. If you don't want to use compressed observations, " +
                              "return SensorCompressionType.None from GetCompressionType()."
                              );
                }
                observationProto = new ObservationProto
                {
                    CompressedData  = ByteString.CopyFrom(compressedObs),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
                };
                // Sensors implementing ISparseChannelSensor also supply a compressed channel mapping,
                // which is copied into the proto.
                var compressibleSensor = sensor as ISparseChannelSensor;
                if (compressibleSensor != null)
                {
                    observationProto.CompressedChannelMapping.AddRange(compressibleSensor.GetCompressedChannelMapping());
                }
            }
            observationProto.Shape.AddRange(shape);
            return(observationProto);
        }
Ejemplo n.º 16
0
        /// <summary>
        /// Generate an ObservationProto for the sensor using the provided ObservationWriter.
        /// Equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// </summary>
        /// <remarks>
        /// A PNG-compressed rank-3 observation with more than three channels is sent
        /// uncompressed instead when the trainer reports that it cannot handle concatenated PNGs.
        /// A null <c>TrainerCapabilities</c> counts as support. The fallback warning is logged
        /// only while <c>s_HaveWarnedAboutTrainerCapabilities</c> is unset.
        /// </remarks>
        /// <param name="sensor">The sensor whose observation is converted.</param>
        /// <param name="observationWriter">Writer that fills the float buffer on the uncompressed path.</param>
        /// <returns>The populated observation proto, with its shape copied from the sensor.</returns>
        /// <exception cref="UnityAgentsException">
        /// The compressed path was taken but GetCompressedObservation() returned null.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
        {
            var shape                = sensor.GetObservationShape();
            var effectiveCompression = sensor.GetCompressionType();

            // Check capabilities if we need to concatenate PNGs
            var needsConcatenatedPng = effectiveCompression == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3;
            if (needsConcatenatedPng)
            {
                var trainerSupportsMultiPng = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
                if (!trainerSupportsMultiPng)
                {
                    if (!s_HaveWarnedAboutTrainerCapabilities)
                    {
                        Debug.LogWarning($"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}.");
                        s_HaveWarnedAboutTrainerCapabilities = true;
                    }
                    effectiveCompression = SensorCompressionType.None;
                }
            }

            ObservationProto proto;
            if (effectiveCompression != SensorCompressionType.None)
            {
                // Compressed: the sensor hands back already-encoded bytes.
                var compressedBytes = sensor.GetCompressedObservation();
                if (compressedBytes == null)
                {
                    throw new UnityAgentsException(
                              $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                              "You must return a byte[]. If you don't want to use compressed observations, " +
                              "return SensorCompressionType.None from GetCompressionType()."
                              );
                }

                proto = new ObservationProto
                {
                    CompressedData  = ByteString.CopyFrom(compressedBytes),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
                };
            }
            else
            {
                // Uncompressed: size the float list up front, then let the sensor write into it.
                var floatCount = sensor.ObservationSize();
                var floatData  = new ObservationProto.Types.FloatData();
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                for (var n = 0; n < floatCount; n++)
                {
                    floatData.Data.Add(0.0f);
                }

                observationWriter.SetTarget(floatData.Data, sensor.GetObservationShape(), 0);
                sensor.Write(observationWriter);

                proto = new ObservationProto
                {
                    FloatData       = floatData,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }

            proto.Shape.AddRange(shape);
            return proto;
        }