/// <summary>
/// Exercises torch.nn.ModuleList: indexing, out-of-range access,
/// counting, enumeration, and growth via extend/append/insert.
/// </summary>
public void ModuleList()
{
    // Two convolutions with different shapes so their reprs are distinguishable.
    var first = new torch.nn.Conv1d(2, 2, 5);
    var second = new torch.nn.Conv1d(1, 1, 3);
    const string firstRepr = "Conv1d(2, 2, kernel_size=(5,), stride=(1,))";
    const string secondRepr = "Conv1d(1, 1, kernel_size=(3,), stride=(1,))";

    var modules = new torch.nn.ModuleList(first, second);

    // Indexing returns the modules in the order they were passed in.
    Assert.AreEqual(first.repr, modules[0].repr);
    Assert.AreNotEqual(first.repr, modules[1].repr);
    Assert.AreEqual(second.repr, modules[1].repr);

    // Indexing past the end is expected to surface as a PythonException.
    Assert.Throws<PythonException>(() => { _ = modules[2]; });

    Assert.AreEqual(2, modules.Count());
    Assert.AreEqual(firstRepr, modules[0].repr);
    Assert.AreEqual(secondRepr, modules[1].repr);
    Assert.AreEqual(firstRepr + "|" + secondRepr, string.Join("|", modules.Select(m => m.repr)));

    // Growth: [first, second] -> extend -> append -> insert at index 2.
    modules.extend(second);
    Assert.AreEqual(3, modules.Count());
    modules.append(first);
    Assert.AreEqual(4, modules.Count());
    modules.insert(2, first);

    var expectedReprs = new[] { first, second, first, second, first }.Select(m => m.repr).ToArray();
    var actualReprs = modules.Select(m => m.repr).ToArray();
    Assert.AreEqual(expectedReprs, actualReprs);
    Assert.AreEqual(5, modules.len());
}
/// <summary>
/// Exercises torch.nn.ModuleDict: key lookup, missing-key access,
/// items/keys/values and direct enumeration, then pop/clear/update,
/// and finally wrapping a stored entry back into a Conv1d.
/// </summary>
public void ModuleDict()
{
    // Two convolutions with different shapes so their reprs are distinguishable.
    var first = new torch.nn.Conv1d(2, 2, 5);
    var second = new torch.nn.Conv1d(1, 1, 3);
    const string firstRepr = "Conv1d(2, 2, kernel_size=(5,), stride=(1,))";
    const string secondRepr = "Conv1d(1, 1, kernel_size=(3,), stride=(1,))";
    const string joinedReprs = firstRepr + "|" + secondRepr;

    var modules = new torch.nn.ModuleDict(("a", first), ("b", second));

    // Lookup by key.
    Assert.AreEqual(first.repr, modules["a"].repr);
    Assert.AreNotEqual(first.repr, modules["b"].repr);
    Assert.AreEqual(second.repr, modules["b"].repr);

    // A missing key is expected to surface as a PythonException.
    Assert.Throws<PythonException>(() => { _ = modules["nothing"]; });

    // items() yields (key, module) tuples in the order they were added.
    var pairs = modules.items().ToArray();
    Assert.AreEqual(2, pairs.Length);
    Assert.AreEqual("a", pairs[0].Item1);
    Assert.AreEqual("b", pairs[1].Item1);
    Assert.AreEqual(firstRepr, pairs[0].Item2.repr);
    Assert.AreEqual(secondRepr, pairs[1].Item2.repr);

    // keys(), values() and enumerating the dict directly agree with items().
    Assert.AreEqual("a, b", string.Join(", ", modules.keys()));
    Assert.AreEqual(joinedReprs, string.Join("|", modules.values().Select(m => m.repr)));
    Assert.AreEqual("a, b", string.Join(", ", modules.Select(p => p.Item1)));
    Assert.AreEqual(joinedReprs, string.Join("|", modules.Select(p => p.Item2.repr)));

    // Mutation: pop returns the removed module, clear empties, update refills.
    Assert.AreEqual(firstRepr, modules.pop("a").repr);
    Assert.AreEqual(1, modules.items().Count());
    modules.clear();
    Assert.AreEqual(0, modules.items().Count());
    modules.update(("a", first), ("b", second));
    Assert.AreEqual(2, modules.items().Count());

    // Construct a Conv1d from the entry stored under "a"; its repr must match the original.
    var rewrapped = new torch.nn.Conv1d(modules["a"]);
    Assert.AreEqual(first.repr, rewrapped.repr);
    Assert.AreEqual(2, modules.len());
}