示例#1
0
        /// <summary>
        /// Converts an Observation into a new ObservationProto.
        /// When <c>obs.CompressedData</c> is non-null it is used as the payload and the float
        /// data is ignored (a warning is logged if float data is present as well); otherwise
        /// the float data is copied into the proto.
        /// </summary>
        /// <param name="obs">The observation to convert.</param>
        /// <returns>A proto holding the payload, the compression type and the shape of <paramref name="obs"/>.</returns>
        public static ObservationProto ToProto(this Observation obs)
        {
            // The compression type is copied from the observation regardless of which payload is used.
            var proto = new ObservationProto
            {
                CompressionType = (CompressionTypeProto)obs.CompressionType,
            };

            if (obs.CompressedData == null)
            {
                // No compressed payload available: send the raw floats.
                var floats = new ObservationProto.Types.FloatData();
                floats.Data.Add(obs.FloatData);
                proto.FloatData = floats;
            }
            else
            {
                // Compressed data wins; warn when the caller also supplied floats.
                if (obs.FloatData.Count != 0)
                {
                    Debug.LogWarning("Observation has both compressed and uncompressed data set. Using compressed.");
                }

                proto.CompressedData = ByteString.CopyFrom(obs.CompressedData);
            }

            proto.Shape.AddRange(obs.Shape);
            return proto;
        }
示例#2
0
        /// <summary>
        /// Builds an ObservationProto for the sensor, letting the sensor write straight into
        /// the proto's float storage through the supplied WriteAdapter.
        /// This is equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// </summary>
        /// <param name="sensor">Sensor whose current observation is serialized.</param>
        /// <param name="writeAdapter">Adapter that is retargeted at the proto's float data before the sensor writes.</param>
        /// <returns>A proto holding either float data or compressed data, plus the sensor's shape.</returns>
        /// <exception cref="UnityAgentsException">
        /// The sensor reports a compression type other than None but returns null compressed data.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, WriteAdapter writeAdapter)
        {
            var sensorShape = sensor.GetObservationShape();
            ObservationProto result;

            if (sensor.GetCompressionType() != SensorCompressionType.None)
            {
                // Compressed path: the sensor hands us the bytes directly.
                var payload = sensor.GetCompressedObservation();
                if (payload == null)
                {
                    throw new UnityAgentsException(
                        $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                        "You must return a byte[]. If you don't want to use compressed observations, " +
                        "return SensorCompressionType.None from GetCompressionType()."
                    );
                }

                result = new ObservationProto
                {
                    CompressedData = ByteString.CopyFrom(payload),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
                };
            }
            else
            {
                // Uncompressed path: pre-fill the repeated field with zeros so the sensor
                // can write into it in place.
                var floatCount = sensor.ObservationSize();
                var floats = new ObservationProto.Types.FloatData();
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                while (floats.Data.Count < floatCount)
                {
                    floats.Data.Add(0.0f);
                }

                // The last argument is presumably the start offset into the target - confirm against WriteAdapter.
                writeAdapter.SetTarget(floats.Data, sensor.GetObservationShape(), 0);
                sensor.Write(writeAdapter);

                result = new ObservationProto
                {
                    FloatData = floats,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }

            result.Shape.AddRange(sensorShape);
            return result;
        }
示例#3
0
        /// <summary>
        /// Generate an ObservationProto for the sensor using the provided WriteAdapter.
        /// This is equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// </summary>
        /// <param name="sensor">Sensor whose current observation is serialized.</param>
        /// <param name="writeAdapter">Adapter that is retargeted at the proto's float data before the sensor writes.</param>
        /// <returns>A proto holding either float data or compressed data, plus the sensor's shape.</returns>
        /// <exception cref="UnityAgentsException">
        /// The sensor reports a compression type other than None but returns null compressed data.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, WriteAdapter writeAdapter)
        {
            var shape = sensor.GetObservationShape();
            ObservationProto observationProto = null;

            if (sensor.GetCompressionType() == SensorCompressionType.None)
            {
                var numFloats      = sensor.ObservationSize();
                var floatDataProto = new ObservationProto.Types.FloatData();
                // Resize the float array
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                for (var i = 0; i < numFloats; i++)
                {
                    floatDataProto.Data.Add(0.0f);
                }

                // The sensor writes its floats in place into the pre-sized repeated field.
                writeAdapter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
                sensor.Write(writeAdapter);

                observationProto = new ObservationProto
                {
                    FloatData       = floatDataProto,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }
            else
            {
                // Fail with a descriptive error instead of letting ByteString.CopyFrom
                // blow up on a null array with no hint about which sensor misbehaved.
                var compressedObs = sensor.GetCompressedObservation();
                if (compressedObs == null)
                {
                    throw new UnityAgentsException(
                        $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                        "You must return a byte[]. If you don't want to use compressed observations, " +
                        "return SensorCompressionType.None from GetCompressionType()."
                    );
                }

                observationProto = new ObservationProto
                {
                    CompressedData  = ByteString.CopyFrom(compressedObs),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
                };
            }
            observationProto.Shape.AddRange(shape);
            return observationProto;
        }
示例#4
0
        /// <summary>
        /// Builds an ObservationProto for the sensor, letting the sensor write straight into
        /// the proto's float storage through the supplied ObservationWriter.
        /// This is equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// Compression is downgraded to None (with a warning logged while the matching
        /// static "have warned" flag is unset) when the attached trainer reports that it
        /// cannot handle concatenated PNGs or non-trivial compressed channel mappings.
        /// </summary>
        /// <param name="sensor">Sensor whose current observation is serialized.</param>
        /// <param name="observationWriter">Writer that is retargeted at the proto's float data before the sensor writes.</param>
        /// <returns>
        /// A proto holding the payload, dimension properties, shape, observation type and,
        /// when non-empty, the sensor name.
        /// </returns>
        /// <exception cref="UnityAgentsException">
        /// The sensor returns null compressed data while compression is in effect, or the
        /// observation is variable-length and the trainer does not support that.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
        {
            var spec = sensor.GetObservationSpec();
            var dims = spec.Shape;
            var compression = sensor.GetCompressionSpec();
            var effectiveCompression = compression.SensorCompressionType;

            // Both capability checks only concern rank-3 shapes whose last dimension exceeds 3
            // (presumably the channel count - confirm against the sensor implementations).
            var lastDimExceedsThree = dims.Length == 3 && dims[2] > 3;

            // Concatenated PNGs need explicit trainer support.
            if (effectiveCompression == SensorCompressionType.PNG && lastDimExceedsThree)
            {
                var trainerLacksMultiPng = Academy.Instance.TrainerCapabilities != null &&
                    !Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
                if (trainerLacksMultiPng)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
                    {
                        Debug.LogWarning(
                            $"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                        );
                        s_HaveWarnedTrainerCapabilitiesMultiPng = true;
                    }
                    effectiveCompression = SensorCompressionType.None;
                }
            }

            // Non-trivial channel mappings for compressed data need explicit trainer support too.
            if (effectiveCompression != SensorCompressionType.None && lastDimExceedsThree)
            {
                var trainerLacksMapping = Academy.Instance.TrainerCapabilities != null &&
                    !Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
                var mappingIsTrivial = compression.IsTrivialMapping();
                if (trainerLacksMapping && !mappingIsTrivial)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMapping)
                    {
                        Debug.LogWarning(
                            $"The sensor {sensor.GetName()} is using non-trivial mapping and " +
                            "the attached trainer doesn't support compression mapping. " +
                            "Switching to uncompressed observations. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                        );
                        s_HaveWarnedTrainerCapabilitiesMapping = true;
                    }
                    effectiveCompression = SensorCompressionType.None;
                }
            }

            ObservationProto proto;
            if (effectiveCompression != SensorCompressionType.None)
            {
                // Compressed path: the sensor hands us the bytes directly.
                var payload = sensor.GetCompressedObservation();
                if (payload == null)
                {
                    throw new UnityAgentsException(
                        $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                        "You must return a byte[]. If you don't want to use compressed observations, " +
                        "return CompressionSpec.Default() from GetCompressionSpec()."
                    );
                }

                proto = new ObservationProto
                {
                    CompressedData = ByteString.CopyFrom(payload),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionSpec().SensorCompressionType,
                };

                if (compression.CompressedChannelMapping != null)
                {
                    proto.CompressedChannelMapping.AddRange(compression.CompressedChannelMapping);
                }
            }
            else
            {
                // Uncompressed path: pre-fill the repeated field with zeros so the sensor
                // can write into it in place.
                var floatCount = sensor.ObservationSize();
                var floats = new ObservationProto.Types.FloatData();
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                while (floats.Data.Count < floatCount)
                {
                    floats.Data.Add(0.0f);
                }

                observationWriter.SetTarget(floats.Data, sensor.GetObservationSpec(), 0);
                sensor.Write(observationWriter);

                proto = new ObservationProto
                {
                    FloatData = floats,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }

            // Copy the per-dimension properties as their integer values.
            var dimProps = spec.DimensionProperties;
            for (var d = 0; d < dimProps.Length; d++)
            {
                proto.DimensionProperties.Add((int)dimProps[d]);
            }

            // A (VariableSize, None) layout is only allowed when the trainer supports it.
            var variableLengthLayout = new InplaceArray <DimensionProperty>(DimensionProperty.VariableSize, DimensionProperty.None);
            if (dimProps == variableLengthLayout)
            {
                var trainerLacksVarLen = Academy.Instance.TrainerCapabilities != null &&
                    !Academy.Instance.TrainerCapabilities.VariableLengthObservation;
                if (trainerLacksVarLen)
                {
                    throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
                }
            }

            for (var d = 0; d < dims.Length; d++)
            {
                proto.Shape.Add(dims[d]);
            }

            // Only set the name when the sensor actually provides one.
            var label = sensor.GetName();
            if (!string.IsNullOrEmpty(label))
            {
                proto.Name = label;
            }

            proto.ObservationType = (ObservationTypeProto)spec.ObservationType;
            return proto;
        }
示例#5
0
        /// <summary>
        /// Builds an ObservationProto for the sensor, letting the sensor write straight into
        /// the proto's float storage through the supplied ObservationWriter.
        /// This is equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// Compression is downgraded to None (with a warning logged while the matching
        /// static "have warned" flag is unset) when the attached trainer reports that it
        /// cannot handle concatenated PNGs or non-trivial compressed channel mappings.
        /// </summary>
        /// <param name="sensor">Sensor whose current observation is serialized.</param>
        /// <param name="observationWriter">Writer that is retargeted at the proto's float data before the sensor writes.</param>
        /// <returns>A proto holding either float data or compressed data, plus the sensor's shape.</returns>
        /// <exception cref="UnityAgentsException">
        /// The sensor returns null compressed data while compression is in effect.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
        {
            var dims = sensor.GetObservationShape();
            var effectiveCompression = sensor.GetCompressionType();

            // Both capability checks only concern rank-3 shapes whose last dimension exceeds 3
            // (presumably the channel count - confirm against the sensor implementations).
            var lastDimExceedsThree = dims.Length == 3 && dims[2] > 3;

            // Concatenated PNGs need explicit trainer support.
            if (effectiveCompression == SensorCompressionType.PNG && lastDimExceedsThree)
            {
                var trainerLacksMultiPng = Academy.Instance.TrainerCapabilities != null &&
                    !Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
                if (trainerLacksMultiPng)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
                    {
                        Debug.LogWarning(
                            $"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                        );
                        s_HaveWarnedTrainerCapabilitiesMultiPng = true;
                    }
                    effectiveCompression = SensorCompressionType.None;
                }
            }

            // Non-trivial channel mappings for compressed data need explicit trainer support too.
            if (effectiveCompression != SensorCompressionType.None && lastDimExceedsThree)
            {
                var trainerLacksMapping = Academy.Instance.TrainerCapabilities != null &&
                    !Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
                var mappingIsTrivial = IsTrivialMapping(sensor);
                if (trainerLacksMapping && !mappingIsTrivial)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMapping)
                    {
                        Debug.LogWarning(
                            $"The sensor {sensor.GetName()} is using non-trivial mapping and " +
                            "the attached trainer doesn't support compression mapping. " +
                            "Switching to uncompressed observations. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                        );
                        s_HaveWarnedTrainerCapabilitiesMapping = true;
                    }
                    effectiveCompression = SensorCompressionType.None;
                }
            }

            ObservationProto proto;
            if (effectiveCompression != SensorCompressionType.None)
            {
                // Compressed path: the sensor hands us the bytes directly.
                var payload = sensor.GetCompressedObservation();
                if (payload == null)
                {
                    throw new UnityAgentsException(
                        $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                        "You must return a byte[]. If you don't want to use compressed observations, " +
                        "return SensorCompressionType.None from GetCompressionType()."
                    );
                }

                proto = new ObservationProto
                {
                    CompressedData = ByteString.CopyFrom(payload),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
                };

                // Sensors that also implement ISparseChannelSensor contribute their channel mapping.
                var sparseSensor = sensor as ISparseChannelSensor;
                if (sparseSensor != null)
                {
                    proto.CompressedChannelMapping.AddRange(sparseSensor.GetCompressedChannelMapping());
                }
            }
            else
            {
                // Uncompressed path: pre-fill the repeated field with zeros so the sensor
                // can write into it in place.
                var floatCount = sensor.ObservationSize();
                var floats = new ObservationProto.Types.FloatData();
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                while (floats.Data.Count < floatCount)
                {
                    floats.Data.Add(0.0f);
                }

                observationWriter.SetTarget(floats.Data, sensor.GetObservationShape(), 0);
                sensor.Write(observationWriter);

                proto = new ObservationProto
                {
                    FloatData = floats,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }

            proto.Shape.AddRange(dims);
            return proto;
        }
示例#6
0
        /// <summary>
        /// Builds an ObservationProto for the sensor, letting the sensor write straight into
        /// the proto's float storage through the supplied ObservationWriter.
        /// This is equivalent to producing an Observation and calling Observation.ToProto(),
        /// but avoids some intermediate memory allocations.
        /// PNG compression is downgraded to None (with a warning logged while the static
        /// "have warned" flag is unset) when the attached trainer reports that it cannot
        /// handle concatenated PNG observations.
        /// </summary>
        /// <param name="sensor">Sensor whose current observation is serialized.</param>
        /// <param name="observationWriter">Writer that is retargeted at the proto's float data before the sensor writes.</param>
        /// <returns>A proto holding either float data or compressed data, plus the sensor's shape.</returns>
        /// <exception cref="UnityAgentsException">
        /// The sensor returns null compressed data while compression is in effect.
        /// </exception>
        public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
        {
            var dims = sensor.GetObservationShape();
            var effectiveCompression = sensor.GetCompressionType();

            // The capability check only concerns rank-3 shapes whose last dimension exceeds 3
            // (presumably the channel count - confirm against the sensor implementations).
            var lastDimExceedsThree = dims.Length == 3 && dims[2] > 3;

            // Concatenated PNGs need explicit trainer support.
            if (effectiveCompression == SensorCompressionType.PNG && lastDimExceedsThree)
            {
                var trainerLacksMultiPng = Academy.Instance.TrainerCapabilities != null &&
                    !Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
                if (trainerLacksMultiPng)
                {
                    if (!s_HaveWarnedAboutTrainerCapabilities)
                    {
                        Debug.LogWarning($"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}.");
                        s_HaveWarnedAboutTrainerCapabilities = true;
                    }
                    effectiveCompression = SensorCompressionType.None;
                }
            }

            ObservationProto proto;
            if (effectiveCompression != SensorCompressionType.None)
            {
                // Compressed path: the sensor hands us the bytes directly.
                var payload = sensor.GetCompressedObservation();
                if (payload == null)
                {
                    throw new UnityAgentsException(
                        $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                        "You must return a byte[]. If you don't want to use compressed observations, " +
                        "return SensorCompressionType.None from GetCompressionType()."
                    );
                }

                proto = new ObservationProto
                {
                    CompressedData = ByteString.CopyFrom(payload),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
                };
            }
            else
            {
                // Uncompressed path: pre-fill the repeated field with zeros so the sensor
                // can write into it in place.
                var floatCount = sensor.ObservationSize();
                var floats = new ObservationProto.Types.FloatData();
                // TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
                while (floats.Data.Count < floatCount)
                {
                    floats.Data.Add(0.0f);
                }

                observationWriter.SetTarget(floats.Data, sensor.GetObservationShape(), 0);
                sensor.Write(observationWriter);

                proto = new ObservationProto
                {
                    FloatData = floats,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }

            proto.Shape.AddRange(dims);
            return proto;
        }