示例#1
0
        /// <summary>
        /// Walks a <c>torch.nn.ParameterDict</c> built from two parameters through
        /// key lookup, enumeration (items / keys / values / direct iteration),
        /// pop, clear, update and len, asserting the expected keys and repr
        /// strings at every step.
        /// </summary>
        public void ParameterDict()
        {
            const string reprA       = "Parameter containing:\ntensor([0.5000], dtype=torch.float64, requires_grad=True)";
            const string reprB       = "Parameter containing:\ntensor([0.3000], dtype=torch.float64, requires_grad=True)";
            const string reprsJoined = reprA + "|" + reprB;

            var paramA     = new torch.nn.Parameter(new[] { 0.5 });
            var paramB     = new torch.nn.Parameter(new[] { 0.3 });
            var parameters = new torch.nn.ParameterDict(("a", paramA), ("b", paramB));

            // Key lookup hands back the parameter stored under that key.
            Assert.AreEqual(paramA.repr, parameters["a"].repr);
            Assert.AreNotEqual(paramA.repr, parameters["b"].repr);
            Assert.AreEqual(paramB.repr, parameters["b"].repr);

            // Looking up a key that was never added is expected to raise PythonException.
            Assert.Throws <PythonException>(() =>
            {
                _ = parameters["nothing"];
            });

            // items() yields (key, parameter) pairs, checked here in insertion order.
            var entries = parameters.items().ToArray();

            Assert.AreEqual(2, entries.Length);
            Assert.AreEqual("a", entries[0].Item1);
            Assert.AreEqual("b", entries[1].Item1);
            Assert.AreEqual(reprA, entries[0].Item2.repr);
            Assert.AreEqual(reprB, entries[1].Item2.repr);

            // keys() / values() and enumerating the dict itself must agree with items().
            Assert.AreEqual("a, b", string.Join(", ", parameters.keys()));
            Assert.AreEqual(reprsJoined, string.Join("|", parameters.values().Select(value => value.repr)));
            Assert.AreEqual("a, b", string.Join(", ", parameters.Select(entry => entry.Item1)));
            Assert.AreEqual(reprsJoined, string.Join("|", parameters.Select(entry => entry.Item2.repr)));

            // pop returns the removed parameter and leaves one entry behind.
            Assert.AreEqual(reprA, parameters.pop("a").repr);
            Assert.AreEqual(1, parameters.items().Count());

            // clear empties the dict; update repopulates it with both entries.
            parameters.clear();
            Assert.AreEqual(0, parameters.items().Count());

            parameters.update(("a", paramA), ("b", paramB));
            Assert.AreEqual(2, parameters.items().Count());
            Assert.AreEqual(2, parameters.len());
        }
示例#2
0
        /// <summary>
        /// Builds an <c>nn.Parameter</c> from the same tensor through three routes:
        /// a dynamic call on the attribute fetched via <c>torch.dynamic_self</c>,
        /// an explicit <c>InvokeMethod</c> call with positional and keyword
        /// arguments, and the typed <c>torch.nn.Parameter</c> wrapper. The string
        /// forms of the last two are asserted equal; the dynamically built one
        /// (created with <c>requires_grad: false</c>) is only printed.
        /// </summary>
        public void ParameterTest()
        {
            var parameterType = torch.dynamic_self.GetAttr("nn").GetAttr("Parameter");

            Console.WriteLine(parameterType.ToString());

            var data = torch.tensor(new double[] { 1, 2, 3 });

            // Route 1: dynamic invocation, passing requires_grad as a named argument.
            var viaDynamic = parameterType(data.PyObject, requires_grad: false);

            // Route 2: explicit InvokeMethod on the "nn" attribute with an args tuple and kwargs.
            var nnModule   = torch.self.GetAttr("nn") as PyObject;
            var positional = new PyTuple(new PyObject[] { data.PyObject });
            var keywords   = Py.kw("requires_grad", new PyObject(Runtime.PyTrue));
            var viaInvoke  = nnModule.InvokeMethod("Parameter", positional, keywords);

            Console.WriteLine(viaDynamic.ToString());
            Console.WriteLine(viaInvoke.ToString());

            // Route 3: the typed wrapper, with requires_grad set to true like route 2.
            var viaWrapper = new torch.nn.Parameter(data, true);

            Assert.AreEqual(viaInvoke.ToString(), viaWrapper.ToString());
        }
示例#3
0
        /// <summary>
        /// Exercises a <c>torch.nn.ParameterList</c> built from two parameters:
        /// index access, an out-of-range index, enumeration, and growth through
        /// <c>extend</c> and <c>append</c>, asserting counts and repr strings
        /// along the way.
        /// </summary>
        public void ParameterList()
        {
            const string reprA       = "Parameter containing:\ntensor([0.5000], dtype=torch.float64, requires_grad=True)";
            const string reprB       = "Parameter containing:\ntensor([0.3000], dtype=torch.float64, requires_grad=True)";
            const string reprsJoined = reprA + "|" + reprB;

            var paramA     = new torch.nn.Parameter(new[] { 0.5 });
            var paramB     = new torch.nn.Parameter(new[] { 0.3 });
            var parameters = new torch.nn.ParameterList(paramA, paramB);

            // Index access returns the parameters in construction order.
            Assert.AreEqual(paramA.repr, parameters[0].repr);
            Assert.AreNotEqual(paramA.repr, parameters[1].repr);
            Assert.AreEqual(paramB.repr, parameters[1].repr);

            // Index 2 is past the end of a two-element list and is expected to raise PythonException.
            Assert.Throws <PythonException>(() =>
            {
                _ = parameters[2];
            });

            Assert.AreEqual(2, parameters.Count());
            Assert.AreEqual(reprA, parameters[0].repr);
            Assert.AreEqual(reprB, parameters[1].repr);
            Assert.AreEqual(reprsJoined, string.Join("|", parameters.Select(item => item.repr)));

            // extend and append each add one element at the end here.
            parameters.extend(paramB);
            Assert.AreEqual(3, parameters.Count());

            parameters.append(paramA);
            Assert.AreEqual(4, parameters.Count());
            Assert.AreEqual(4, parameters.len());

            // Final contents, in order: a, b, b, a.
            var expectedReprs = new[] { paramA, paramB, paramB, paramA }.Select(item => item.repr).ToArray();
            var actualReprs   = parameters.Select(item => item.repr).ToArray();

            Assert.AreEqual(expectedReprs, actualReprs);
        }