예제 #1
0
        /**
         * Derive this cycle's active and winner cells from the active columns and
         * the distal segments the Connections reports as active / matching.
         * When learning is enabled, synapses are grown, reinforced or punished.
         *
         * <pre>
         * Columns, active segments and matching segments are grouped by column;
         * for every group:
         *   active column, has active segments   -> ActivatePredictedColumn
         *   active column, no active segments    -> BurstColumn
         *   inactive column, only when learning  -> PunishPredictedColumn
         * </pre>
         *
         * @param conn                Connections instance holding cells, segments and parameters
         * @param cycle               ComputeCycle whose ActiveCells() / WinnerCells() sets are added to
         * @param activeColumnIndices indices of the currently active columns (sorted before grouping)
         * @param learn               when true, PunishPredictedColumn runs for inactive columns;
         *                            the flag is also forwarded to the other two handlers
         */
        public void ActivateCells(Connections conn, ComputeCycle cycle, int[] activeColumnIndices, bool learn)
        {
            ColumnData columnData = new ColumnData();

            HashSet<Cell> prevActiveCells = conn.GetActiveCells();
            HashSet<Cell> prevWinnerCells = conn.GetWinnerCells();

            // Columns are handed to the grouper in ascending index order.
            List<object> sortedColumns = activeColumnIndices
                .OrderBy(index => index)
                .Select(index => (object)conn.GetColumn(index))
                .ToList();

            // Both segment lists are keyed by the column that owns the segment's parent cell.
            Func<object, Column> columnOfSegment =
                item => ((DistalDendrite)item).GetParentCell().GetColumn();

            GroupBy2<Column> grouper = GroupBy2<Column>.Of(
                new Tuple<List<object>, Func<object, Column>>(sortedColumns, item => (Column)item),
                new Tuple<List<object>, Func<object, Column>>(
                    new List<DistalDendrite>(conn.GetActiveSegments()).Cast<object>().ToList(), columnOfSegment),
                new Tuple<List<object>, Func<object, Column>>(
                    new List<DistalDendrite>(conn.GetMatchingSegments()).Cast<object>().ToList(), columnOfSegment));

            double permanenceIncrement = conn.GetPermanenceIncrement();
            double permanenceDecrement = conn.GetPermanenceDecrement();

            foreach (Tuple group in grouper)
            {
                columnData = columnData.Set(group);

                if (!columnData.IsNotNone(ACTIVE_COLUMNS))
                {
                    // Inactive column: its segments are only touched while learning.
                    if (learn)
                    {
                        PunishPredictedColumn(conn, columnData.ActiveSegments(), columnData.MatchingSegments(),
                                              prevActiveCells, prevWinnerCells, conn.GetPredictedSegmentDecrement());
                    }
                    continue;
                }

                if (!columnData.ActiveSegments().Any())
                {
                    // Active but unpredicted: BurstColumn returns (active cells, winner cell).
                    Tuple burst = BurstColumn(conn, columnData.Column(), columnData.MatchingSegments(),
                                              prevActiveCells, prevWinnerCells, permanenceIncrement, permanenceDecrement,
                                              conn.GetRandom(), learn);

                    cycle.ActiveCells().UnionWith((IEnumerable<Cell>)burst.Get(0));
                    cycle.WinnerCells().Add((Cell)burst.Get(1));
                    continue;
                }

                // Active and predicted: the returned cells are both active and winners.
                List<Cell> predictedCells = ActivatePredictedColumn(conn, columnData.ActiveSegments(),
                                                                    columnData.MatchingSegments(), prevActiveCells, prevWinnerCells,
                                                                    permanenceIncrement, permanenceDecrement, learn);

                cycle.ActiveCells().UnionWith(predictedCells);
                cycle.WinnerCells().UnionWith(predictedCells);
            }
        }
예제 #2
0
        public void TestActivateCorrectlyPredictiveCells()
        {
            var memory      = new TemporalMemory();
            var connections = new Connections();
            Parameters parameters = GetDefaultParameters();

            parameters.Apply(connections);
            TemporalMemory.Init(connections);

            // Cell 4 gets one distal segment whose synapses reach cells 0-3.
            Cell predictedCell = connections.GetCell(4);
            var expected = new HashSet<Cell> { predictedCell };

            DistalDendrite segment = connections.CreateSegment(predictedCell);
            for (int presynaptic = 0; presynaptic < 4; presynaptic++)
            {
                connections.CreateSynapse(segment, connections.GetCell(presynaptic), 0.5);
            }

            // Activating column 0 should make cell 4 predictive...
            ComputeCycle first = memory.Compute(connections, new[] { 0 }, true);
            Assert.IsTrue(first.PredictiveCells().SetEquals(expected));

            // ...and activating column 1 next should activate exactly that cell.
            ComputeCycle second = memory.Compute(connections, new[] { 1 }, true);
            Assert.IsTrue(second.ActiveCells().SetEquals(expected));
        }
예제 #3
0
        public void testZeroActiveColumns()
        {
            var memory      = new TemporalMemory();
            var connections = new Connections();
            Parameters parameters = GetDefaultParameters();

            parameters.Apply(connections);
            TemporalMemory.Init(connections);

            // Give cell 4 a segment onto cells 0-3 so the first compute produces
            // non-empty active, winner and predictive sets.
            DistalDendrite segment = connections.CreateSegment(connections.GetCell(4));
            foreach (int source in new[] { 0, 1, 2, 3 })
            {
                connections.CreateSynapse(segment, connections.GetCell(source), 0.5);
            }

            ComputeCycle warmUp = memory.Compute(connections, new[] { 0 }, true);

            Assert.IsTrue(warmUp.ActiveCells().Count > 0);
            Assert.IsTrue(warmUp.WinnerCells().Count > 0);
            Assert.IsTrue(warmUp.PredictiveCells().Count > 0);

            // An empty input must leave every cell set empty.
            ComputeCycle empty = memory.Compute(connections, new int[0], true);

            Assert.AreEqual(0, empty.ActiveCells().Count);
            Assert.AreEqual(0, empty.WinnerCells().Count);
            Assert.AreEqual(0, empty.PredictiveCells().Count);
        }
예제 #4
0
        public void testBurstUnpredictedColumns()
        {
            var memory      = new TemporalMemory();
            var connections = new Connections();
            Parameters parameters = GetDefaultParameters();

            parameters.Apply(connections);
            TemporalMemory.Init(connections);

            // No segments exist, so activating column 0 must activate all of cells 0-3.
            HashSet<Cell> expectedBurst = connections.GetCellSet(new[] { 0, 1, 2, 3 });
            ComputeCycle result = memory.Compute(connections, new[] { 0 }, true);

            Assert.IsTrue(result.ActiveCells().SetEquals(expectedBurst));
        }
예제 #5
0
        public void testAddSegmentToCellWithFewestSegments()
        {
            bool sawGrowthOnCell1 = false;
            bool sawGrowthOnCell2 = false;

            // Cells 0 and 3 each already own a segment, so the newly grown segment
            // must land on cell 1 or cell 2; across 100 seeds both must be observed.
            for (int seed = 0; seed < 100; seed++)
            {
                var memory      = new TemporalMemory();
                var connections = new Connections();

                Parameters parameters = GetDefaultParameters(null, Parameters.KEY.MAX_NEW_SYNAPSE_COUNT, 4);
                parameters = GetDefaultParameters(parameters, Parameters.KEY.PREDICTED_SEGMENT_DECREMENT, 0.02);
                parameters = GetDefaultParameters(parameters, Parameters.KEY.SEED, seed);
                parameters.SetParameterByKey(Parameters.KEY.RANDOM, new XorshiftRandom(seed));
                parameters.Apply(connections);
                TemporalMemory.Init(connections);

                int[]  previousColumns = { 1, 2, 3, 4 };
                Cell[] previousCells   = { connections.GetCell(4), connections.GetCell(5), connections.GetCell(6), connections.GetCell(7) };
                int[]  currentColumns  = { 0 };
                Cell[] occupiedCells   = { connections.GetCell(0), connections.GetCell(3) };
                HashSet<Cell> expectedActive = connections.GetCellSet(new[] { 0, 1, 2, 3 });

                DistalDendrite segment1 = connections.CreateSegment(occupiedCells[0]);
                connections.CreateSynapse(segment1, previousCells[0], 0.5);
                DistalDendrite segment2 = connections.CreateSegment(occupiedCells[1]);
                connections.CreateSynapse(segment2, previousCells[1], 0.5);

                memory.Compute(connections, previousColumns, true);
                ComputeCycle cycle = memory.Compute(connections, currentColumns, true);

                Assert.IsTrue(cycle.ActiveCells().SetEquals(expectedActive));

                // Exactly one segment was added and the two seeded ones are unchanged.
                Assert.AreEqual(3, connections.GetNumSegments());
                Assert.AreEqual(1, connections.GetNumSegments(connections.GetCell(0)));
                Assert.AreEqual(1, connections.GetNumSegments(connections.GetCell(3)));
                Assert.AreEqual(1, connections.GetNumSynapses(segment1));
                Assert.AreEqual(1, connections.GetNumSynapses(segment2));

                var grown = new List<DistalDendrite>(connections.GetSegments(connections.GetCell(1)));
                if (grown.Count > 0)
                {
                    sawGrowthOnCell1 = true;
                }
                else
                {
                    List<DistalDendrite> onCell2 = connections.GetSegments(connections.GetCell(2));
                    Assert.IsTrue(onCell2.Count > 0);
                    sawGrowthOnCell2 = true;
                    grown.AddRange(onCell2);
                }

                Assert.AreEqual(1, grown.Count);
                List<Synapse> synapses = grown[0].GetAllSynapses(connections);
                Assert.AreEqual(4, synapses.Count);

                // Each new synapse has permanence ~0.2 and comes from a distinct
                // previously active column; HashSet.Remove reports whether it was present.
                HashSet<Column> unseenColumns = connections.GetColumnSet(previousColumns);
                foreach (Synapse synapse in synapses)
                {
                    Assert.AreEqual(0.2, synapse.GetPermanence(), 0.01);

                    Column source = synapse.GetPresynapticCell().GetColumn();
                    Assert.IsTrue(unseenColumns.Remove(source));
                }

                Assert.AreEqual(0, unseenColumns.Count);
            }

            Assert.IsTrue(sawGrowthOnCell1);
            Assert.IsTrue(sawGrowthOnCell2);
        }