Example #1
0
        public void ModuleList()
        {
            // Expected string forms of the two Conv1d modules built below.
            const string reprA = "Conv1d(2, 2, kernel_size=(5,), stride=(1,))";
            const string reprB = "Conv1d(1, 1, kernel_size=(3,), stride=(1,))";

            var first   = new torch.nn.Conv1d(2, 2, 5);
            var second  = new torch.nn.Conv1d(1, 1, 3);
            var modules = new torch.nn.ModuleList(first, second);

            // Positional indexing returns the modules in construction order.
            Assert.AreEqual(first.repr, modules[0].repr);
            Assert.AreNotEqual(first.repr, modules[1].repr);
            Assert.AreEqual(second.repr, modules[1].repr);

            // Indexing past the end surfaces as a Python-side exception.
            Assert.Throws<PythonException>(() => { _ = modules[2]; });

            // Count, per-element repr, and enumeration all agree.
            Assert.AreEqual(2, modules.Count());
            Assert.AreEqual(reprA, modules[0].repr);
            Assert.AreEqual(reprB, modules[1].repr);
            Assert.AreEqual(reprA + "|" + reprB, string.Join("|", modules.Select(m => m.repr)));

            // extend / append / insert each grow the list in place.
            modules.extend(second);
            Assert.AreEqual(3, modules.Count());
            modules.append(first);
            Assert.AreEqual(4, modules.Count());
            modules.insert(2, first);

            // After the mutations above the order is [first, second, first, second, first].
            var expected = new[] { first, second, first, second, first }.Select(m => m.repr).ToArray();
            var actual   = modules.Select(m => m.repr).ToArray();
            Assert.AreEqual(expected, actual);
            Assert.AreEqual(5, modules.len());
        }
Example #2
0
        public void ModuleDict()
        {
            // Expected string forms of the two Conv1d modules built below.
            const string reprA = "Conv1d(2, 2, kernel_size=(5,), stride=(1,))";
            const string reprB = "Conv1d(1, 1, kernel_size=(3,), stride=(1,))";

            var convA   = new torch.nn.Conv1d(2, 2, 5);
            var convB   = new torch.nn.Conv1d(1, 1, 3);
            var modules = new torch.nn.ModuleDict(("a", convA), ("b", convB));

            // Lookup by key returns the module registered under that key.
            Assert.AreEqual(convA.repr, modules["a"].repr);
            Assert.AreNotEqual(convA.repr, modules["b"].repr);
            Assert.AreEqual(convB.repr, modules["b"].repr);

            // Looking up an unknown key surfaces as a Python-side exception.
            Assert.Throws<PythonException>(() => { _ = modules["nothing"]; });

            // items() yields (key, module) pairs; the test pins the "a", "b" order.
            var pairs = modules.items().ToArray();

            Assert.AreEqual(2, pairs.Length);
            Assert.AreEqual("a", pairs[0].Item1);
            Assert.AreEqual("b", pairs[1].Item1);
            Assert.AreEqual(reprA, pairs[0].Item2.repr);
            Assert.AreEqual(reprB, pairs[1].Item2.repr);

            // keys(), values(), and direct enumeration agree with items().
            Assert.AreEqual("a, b", string.Join(", ", modules.keys()));
            Assert.AreEqual(reprA + "|" + reprB, string.Join("|", modules.values().Select(m => m.repr)));
            Assert.AreEqual("a, b", string.Join(", ", modules.Select(p => p.Item1)));
            Assert.AreEqual(reprA + "|" + reprB, string.Join("|", modules.Select(p => p.Item2.repr)));

            // pop / clear / update mutate the dict in place.
            Assert.AreEqual(reprA, modules.pop("a").repr);
            Assert.AreEqual(1, modules.items().Count());
            modules.clear();
            Assert.AreEqual(0, modules.items().Count());
            modules.update(("a", convA), ("b", convB));
            Assert.AreEqual(2, modules.items().Count());

            // Wrapping a looked-up module in a new Conv1d keeps the same repr.
            var rewrapped = new torch.nn.Conv1d(modules["a"]);

            Assert.AreEqual(convA.repr, rewrapped.repr);
            Assert.AreEqual(2, modules.len());
        }