コード例 #1
0
        /// <summary>
        /// Produce an independent duplicate of this merge parameter.
        /// </summary>
        /// <returns>A freshly constructed MergeParameter whose settings were copied from this instance.</returns>
        public override LayerParameterBase Clone()
        {
            MergeParameter duplicate = new MergeParameter();
            duplicate.Copy(this);

            return duplicate;
        }
コード例 #2
0
        /// <summary>
        /// Read a merge parameter from a binary reader.
        /// </summary>
        /// <param name="br">Specifies the binary reader; one string holding the RawProto text is read from it.</param>
        /// <param name="bNewInstance">When <i>true</i> (the default) the loaded values are only returned as a new instance;
        /// when <i>false</i> they are additionally copied into this instance.</param>
        /// <returns>The newly parsed parameter instance (returned in both cases).</returns>
        public override object Load(System.IO.BinaryReader br, bool bNewInstance = true)
        {
            string strProtoText = br.ReadString();
            RawProto proto = RawProto.Parse(strProtoText);
            MergeParameter loaded = FromProto(proto);

            if (bNewInstance)
            {
                return loaded;
            }

            // Reusing this instance: pull the freshly loaded settings into it.
            Copy(loaded);
            return loaded;
        }
コード例 #3
0
        /// <summary>
        /// Copy the settings of another merge parameter into this instance.
        /// </summary>
        /// <param name="src">Specifies the parameter to copy from; it is cast directly to MergeParameter,
        /// so any other parameter type causes an InvalidCastException.</param>
        public override void Copy(LayerParameterBase src)
        {
            MergeParameter other = (MergeParameter)src;

            // Axis selection and overall copy count.
            m_nCopyAxis = other.m_nCopyAxis;
            m_nOrderingMajorAxis = other.m_nOrderingMajorAxis;
            m_nCopyCount = other.m_nCopyCount;

            // Index / dimension settings with the '1' suffix.
            m_nSrcStartIdx1 = other.m_nSrcStartIdx1;
            m_nDstStartIdx1 = other.m_nDstStartIdx1;
            m_nCopyDim1 = other.m_nCopyDim1;
            m_nSrcSpatialDimStartIdx1 = other.m_nSrcSpatialDimStartIdx1;
            m_nDstSpatialDimStartIdx1 = other.m_nDstSpatialDimStartIdx1;

            // Index / dimension settings with the '2' suffix.
            m_nSrcStartIdx2 = other.m_nSrcStartIdx2;
            m_nDstStartIdx2 = other.m_nDstStartIdx2;
            m_nCopyDim2 = other.m_nCopyDim2;
            m_nSrcSpatialDimStartIdx2 = other.m_nSrcSpatialDimStartIdx2;
            m_nDstSpatialDimStartIdx2 = other.m_nDstSpatialDimStartIdx2;

            // Spatial-dimension copy settings.
            m_nSpatialDimCopyCount = other.m_nSpatialDimCopyCount;
            m_nDstSpatialDim = other.m_nDstSpatialDim;
        }
コード例 #4
0
        /// <summary>
        /// Calculate the new shape based on the merge parameter settings and the specified input shapes.
        /// </summary>
        /// <remarks>
        /// Trailing axes of size 1 are first removed (in place) from whichever input has more axes,
        /// stopping once both inputs have the same number of axes or the longer one ends in an axis
        /// that is not 1.  The returned shape has as many axes as the (trimmed) first input:
        /// <list type="bullet">
        /// <item>axes before copy_axis are taken from the inputs, which must agree;</item>
        /// <item>the copy_axis dimension is copy_dim1 + copy_dim2;</item>
        /// <item>the axis right after copy_axis (when it exists) is taken from the first input;</item>
        /// <item>for the remaining axes: when dst_spatialdim &gt; 0 the first remaining axis is set to it and
        /// later axes stay 1; otherwise, when spatialdim_copy_count &lt;= 0, they are taken from the inputs
        /// (which must agree), else they stay 1.</item>
        /// </list>
        /// </remarks>
        /// <param name="log">Specifies the output log used for the validation checks.</param>
        /// <param name="p">Specifies the merge parameter.</param>
        /// <param name="rgShape1">Specifies the shape of the first input (trailing 1's may be removed in place).</param>
        /// <param name="rgShape2">Specifies the shape of the second input (trailing 1's may be removed in place).</param>
        /// <returns>The new shape is returned.</returns>
        public static List <int> Reshape(Log log, MergeParameter p, List <int> rgShape1, List <int> rgShape2)
        {
            // Trim trailing unit axes from the longer shape.  The loop must stop as soon as the last
            // axis is not 1; previously nothing was removed in that case and the loop never ended.
            // (Count > other.Count implies Count > 0, so the index below is always valid.)
            while (rgShape2.Count > rgShape1.Count && rgShape2[rgShape2.Count - 1] == 1)
            {
                rgShape2.RemoveAt(rgShape2.Count - 1);
            }

            while (rgShape1.Count > rgShape2.Count && rgShape1[rgShape1.Count - 1] == 1)
            {
                rgShape1.RemoveAt(rgShape1.Count - 1);
            }

            log.CHECK_EQ(rgShape1.Count, rgShape2.Count, "The inputs must have the same number of axes.");
            log.CHECK_LT(p.copy_axis, rgShape1.Count, "There must be more axes than the copy axis!");

            // NOTE(review): these canonical indices are computed but never used below; they are kept
            // in case Utility.CanonicalAxisIndex performs validation - confirm and remove if it does not.
            int nSrcStartIdx1 = Utility.CanonicalAxisIndex(p.src_start_idx1, rgShape1[p.copy_axis]);
            int nSrcStartIdx2 = Utility.CanonicalAxisIndex(p.src_start_idx2, rgShape2[p.copy_axis]);
            int nDstStartIdx1 = Utility.CanonicalAxisIndex(p.dst_start_idx1, rgShape1[p.copy_axis]);
            int nDstStartIdx2 = Utility.CanonicalAxisIndex(p.dst_start_idx2, rgShape2[p.copy_axis]);

            List <int> rgNewShape = new List <int>(rgShape1.Count);

            for (int i = 0; i < rgShape1.Count; i++)
            {
                rgNewShape.Add(1);
            }

            for (int i = 0; i < p.copy_axis; i++)
            {
                log.CHECK_EQ(rgShape1[i], rgShape2[i], "Inputs must have the same dimensions up to the copy axis.");
                rgNewShape[i] = rgShape1[i];
            }

            int nIdx = p.copy_axis;

            rgNewShape[nIdx] = p.copy_dim1 + p.copy_dim2;
            nIdx++;

            // The axis following the copy axis only exists when the copy axis is not the last axis;
            // previously this indexed past the end of the list in that case.
            if (nIdx < rgNewShape.Count)
            {
                rgNewShape[nIdx] = rgShape1[nIdx];
            }
            nIdx++;

            for (int i = nIdx; i < rgNewShape.Count; i++)
            {
                if (p.m_nDstSpatialDim > 0)
                {
                    rgNewShape[i] = p.m_nDstSpatialDim;
                    break;
                }

                if (p.spatialdim_copy_count <= 0)
                {
                    log.CHECK_EQ(rgShape1[i], rgShape2[i], "Inputs must have the same dimensions after the copy axis.");
                    rgNewShape[i] = rgShape1[i];
                }
            }

            return(rgNewShape);
        }
コード例 #5
0
        /// <summary>
        /// Parses the parameter from a RawProto.
        /// </summary>
        /// <remarks>
        /// Each recognized field present in the proto overrides the corresponding value of a new
        /// MergeParameter; absent fields keep the new instance's values.  Integers are parsed with the
        /// invariant culture so the result does not depend on the current thread culture (e.g. on how
        /// that culture writes the negative sign).  The copy_axis and order_major_axis fields are not
        /// read here (their parsing was disabled in the original code).
        /// </remarks>
        /// <param name="rp">Specifies the RawProto to parse.</param>
        /// <returns>A new instance of the parameter is returned.</returns>
        /// <exception cref="System.FormatException">Thrown when a present field is not a valid integer.</exception>
        /// <exception cref="System.OverflowException">Thrown when a present field does not fit in an int.</exception>
        public static MergeParameter FromProto(RawProto rp)
        {
            string         strVal;
            MergeParameter p         = new MergeParameter();
            System.Globalization.CultureInfo invariant = System.Globalization.CultureInfo.InvariantCulture;

            if ((strVal = rp.FindValue("copy_count")) != null)
            {
                p.copy_count = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("src_start_idx1")) != null)
            {
                p.src_start_idx1 = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("dst_start_idx1")) != null)
            {
                p.dst_start_idx1 = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("copy_dim1")) != null)
            {
                p.copy_dim1 = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("src_start_idx2")) != null)
            {
                p.src_start_idx2 = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("dst_start_idx2")) != null)
            {
                p.dst_start_idx2 = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("copy_dim2")) != null)
            {
                p.copy_dim2 = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("src_spatialdim_start_idx1")) != null)
            {
                p.src_spatialdim_start_idx1 = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("dst_spatialdim_start_idx1")) != null)
            {
                p.dst_spatialdim_start_idx1 = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("src_spatialdim_start_idx2")) != null)
            {
                p.src_spatialdim_start_idx2 = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("dst_spatialdim_start_idx2")) != null)
            {
                p.dst_spatialdim_start_idx2 = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("spatialdim_copy_count")) != null)
            {
                p.spatialdim_copy_count = int.Parse(strVal, invariant);
            }

            if ((strVal = rp.FindValue("dst_spatialdim")) != null)
            {
                p.dst_spatialdim = int.Parse(strVal, invariant);
            }

            return(p);
        }