/// <summary>
/// Calculates the active cells from the currently active columns and the
/// dendrite segments recorded on <paramref name="conn"/>, growing and
/// reinforcing synapses along the way.
///
/// Active columns, active segments and matching segments are grouped by
/// column; for every group:
/// <list type="bullet">
/// <item>column active and it has active segments: <c>ActivatePredictedColumn</c></item>
/// <item>column active without active segments: <c>BurstColumn</c></item>
/// <item>column not active and <paramref name="learn"/> is set: <c>PunishPredictedColumn</c></item>
/// </list>
/// </summary>
/// <param name="conn">Connections object supplying the previous active/winner cells,
/// the active and matching segments, and the permanence settings.</param>
/// <param name="cycle">Compute cycle whose active-cell and winner-cell sets are extended.</param>
/// <param name="activeColumnIndices">Indices of the active columns; processed in ascending order.</param>
/// <param name="learn">Forwarded to the activate/burst helpers; also gates the punish step.</param>
public void ActivateCells(Connections conn, ComputeCycle cycle, int[] activeColumnIndices, bool learn)
{
    ColumnData columnData = new ColumnData();

    HashSet<Cell> previouslyActive = conn.GetActiveCells();
    HashSet<Cell> previousWinners = conn.GetWinnerCells();

    // Resolve the indices to columns in ascending index order, boxed for the grouper.
    List<object> sortedColumns = new List<object>();
    foreach (int columnIndex in activeColumnIndices.OrderBy(i => i))
    {
        sortedColumns.Add(conn.GetColumn(columnIndex));
    }

    // Key selectors: a column is its own key, a segment is keyed by its parent cell's column.
    Func<object, Column> columnAsKey = item => (Column)item;
    Func<object, Column> segmentToColumn = item => ((DistalDendrite)item).GetParentCell().GetColumn();

    List<object> activeSegmentItems =
        new List<DistalDendrite>(conn.GetActiveSegments()).Cast<object>().ToList();
    List<object> matchingSegmentItems =
        new List<DistalDendrite>(conn.GetMatchingSegments()).Cast<object>().ToList();

    GroupBy2<Column> grouper = GroupBy2<Column>.Of(
        new Tuple<List<object>, Func<object, Column>>(sortedColumns, columnAsKey),
        new Tuple<List<object>, Func<object, Column>>(activeSegmentItems, segmentToColumn),
        new Tuple<List<object>, Func<object, Column>>(matchingSegmentItems, segmentToColumn));

    double increment = conn.GetPermanenceIncrement();
    double decrement = conn.GetPermanenceDecrement();

    foreach (Tuple group in grouper)
    {
        columnData = columnData.Set(group);

        // Group has no active column: only the (learning-time) punishment applies.
        if (!columnData.IsNotNone(ACTIVE_COLUMNS))
        {
            if (learn)
            {
                PunishPredictedColumn(
                    conn,
                    columnData.ActiveSegments(),
                    columnData.MatchingSegments(),
                    previouslyActive,
                    previousWinners,
                    conn.GetPredictedSegmentDecrement());
            }
            continue;
        }

        // Active column without any active segment: burst it.
        if (!columnData.ActiveSegments().Any())
        {
            Tuple burstResult = BurstColumn(
                conn,
                columnData.Column(),
                columnData.MatchingSegments(),
                previouslyActive,
                previousWinners,
                increment,
                decrement,
                conn.GetRandom(),
                learn);

            // Element 0 holds the cells to activate, element 1 the single winner cell.
            cycle.ActiveCells().UnionWith((IEnumerable<Cell>)burstResult.Get(0));
            cycle.WinnerCells().Add((Cell)burstResult.Get(1));
            continue;
        }

        // Active column with active segments: the returned cells are both active and winners.
        List<Cell> predictedCells = ActivatePredictedColumn(
            conn,
            columnData.ActiveSegments(),
            columnData.MatchingSegments(),
            previouslyActive,
            previousWinners,
            increment,
            decrement,
            learn);

        cycle.ActiveCells().UnionWith(predictedCells);
        cycle.WinnerCells().UnionWith(predictedCells);
    }
}
/// <summary>
/// Builds a segment on cell 4 with synapses (permanence 0.5) from cells 0-3,
/// then checks that computing with column 0 active leaves exactly cell 4
/// predictive, and that a following compute with column 1 active leaves
/// exactly cell 4 active.
/// </summary>
public void TestActivateCorrectlyPredictiveCells()
{
    TemporalMemory memory = new TemporalMemory();
    Connections connections = new Connections();
    Parameters parameters = GetDefaultParameters();
    parameters.Apply(connections);
    TemporalMemory.Init(connections);

    int[] firstStepColumns = { 0 };
    int[] secondStepColumns = { 1 };

    Cell predictedCell = connections.GetCell(4);
    HashSet<Cell> expectedCells = new HashSet<Cell> { predictedCell };

    // Connect the segment to cells 0 through 3, in that order.
    DistalDendrite segment = connections.CreateSegment(predictedCell);
    for (int presynapticIndex = 0; presynapticIndex < 4; presynapticIndex++)
    {
        connections.CreateSynapse(segment, connections.GetCell(presynapticIndex), 0.5);
    }

    ComputeCycle firstCycle = memory.Compute(connections, firstStepColumns, true);
    Assert.IsTrue(firstCycle.PredictiveCells().SetEquals(expectedCells));

    ComputeCycle secondCycle = memory.Compute(connections, secondStepColumns, true);
    Assert.IsTrue(secondCycle.ActiveCells().SetEquals(expectedCells));
}
/// <summary>
/// Checks that a compute step with no active columns yields empty active,
/// winner and predictive cell sets, after a preceding step (column 0 active,
/// with a segment on cell 4 fed by cells 0-3) produced non-empty ones.
/// </summary>
public void testZeroActiveColumns()
{
    TemporalMemory memory = new TemporalMemory();
    Connections connections = new Connections();
    Parameters parameters = GetDefaultParameters();
    parameters.Apply(connections);
    TemporalMemory.Init(connections);

    int[] initialColumns = { 0 };

    // Segment on cell 4 with synapses (permanence 0.5) from cells 0 through 3, in that order.
    Cell targetCell = connections.GetCell(4);
    DistalDendrite segment = connections.CreateSegment(targetCell);
    for (int presynapticIndex = 0; presynapticIndex < 4; presynapticIndex++)
    {
        connections.CreateSynapse(segment, connections.GetCell(presynapticIndex), 0.5);
    }

    // First step: all three result sets must be populated.
    ComputeCycle populatedCycle = memory.Compute(connections, initialColumns, true);
    Assert.IsFalse(populatedCycle.ActiveCells().Count == 0);
    Assert.IsFalse(populatedCycle.WinnerCells().Count == 0);
    Assert.IsFalse(populatedCycle.PredictiveCells().Count == 0);

    // Second step with no active columns: all three result sets must be empty.
    int[] noColumns = new int[0];
    ComputeCycle emptyCycle = memory.Compute(connections, noColumns, true);
    Assert.IsTrue(emptyCycle.ActiveCells().Count == 0);
    Assert.IsTrue(emptyCycle.WinnerCells().Count == 0);
    Assert.IsTrue(emptyCycle.PredictiveCells().Count == 0);
}
/// <summary>
/// Checks that computing with column 0 active, and no segments created
/// beforehand, makes exactly cells 0-3 active.
/// </summary>
public void testBurstUnpredictedColumns()
{
    TemporalMemory memory = new TemporalMemory();
    Connections connections = new Connections();
    Parameters parameters = GetDefaultParameters();
    parameters.Apply(connections);
    TemporalMemory.Init(connections);

    int[] columnsToActivate = { 0 };
    HashSet<Cell> expectedBurst = connections.GetCellSet(new int[] { 0, 1, 2, 3 });

    ComputeCycle result = memory.Compute(connections, columnsToActivate, true);

    Assert.IsTrue(result.ActiveCells().SetEquals(expectedBurst));
}
/// <summary>
/// For seeds 0-99: cells 0 and 3 each get one single-synapse segment up front,
/// then columns 1-4 are activated followed by column 0. The test checks that
/// exactly one new segment appears, on cell 1 or cell 2, carrying four synapses
/// (permanence 0.2 +/- 0.01) whose presynaptic cells cover each of the
/// previously active columns once. Across all seeds both cell 1 and cell 2
/// must have been chosen at least once.
/// </summary>
public void testAddSegmentToCellWithFewestSegments()
{
    bool sawGrowthOnCell1 = false;
    bool sawGrowthOnCell2 = false;

    for (int seed = 0; seed < 100; seed++)
    {
        TemporalMemory memory = new TemporalMemory();
        Connections connections = new Connections();

        Parameters parameters = GetDefaultParameters(null, Parameters.KEY.MAX_NEW_SYNAPSE_COUNT, 4);
        parameters = GetDefaultParameters(parameters, Parameters.KEY.PREDICTED_SEGMENT_DECREMENT, 0.02);
        parameters = GetDefaultParameters(parameters, Parameters.KEY.SEED, seed);
        parameters.SetParameterByKey(Parameters.KEY.RANDOM, new XorshiftRandom(seed));
        parameters.Apply(connections);
        TemporalMemory.Init(connections);

        int[] earlierColumns = { 1, 2, 3, 4 };
        Cell[] earlierCells =
        {
            connections.GetCell(4),
            connections.GetCell(5),
            connections.GetCell(6),
            connections.GetCell(7)
        };
        int[] currentColumns = { 0 };
        Cell[] cellsWithSegments = { connections.GetCell(0), connections.GetCell(3) };
        HashSet<Cell> expectedActive = connections.GetCellSet(new int[] { 0, 1, 2, 3 });

        // One single-synapse segment on cell 0 and one on cell 3.
        DistalDendrite segmentOnCell0 = connections.CreateSegment(cellsWithSegments[0]);
        connections.CreateSynapse(segmentOnCell0, earlierCells[0], 0.5);
        DistalDendrite segmentOnCell3 = connections.CreateSegment(cellsWithSegments[1]);
        connections.CreateSynapse(segmentOnCell3, earlierCells[1], 0.5);

        memory.Compute(connections, earlierColumns, true);
        ComputeCycle result = memory.Compute(connections, currentColumns, true);

        Assert.IsTrue(result.ActiveCells().SetEquals(expectedActive));

        // Exactly one segment was added, and the pre-existing ones are untouched.
        Assert.AreEqual(3, connections.GetNumSegments());
        Assert.AreEqual(1, connections.GetNumSegments(connections.GetCell(0)));
        Assert.AreEqual(1, connections.GetNumSegments(connections.GetCell(3)));
        Assert.AreEqual(1, connections.GetNumSynapses(segmentOnCell0));
        Assert.AreEqual(1, connections.GetNumSynapses(segmentOnCell3));

        // The new segment lives on cell 1, or failing that on cell 2.
        List<DistalDendrite> grownSegments =
            new List<DistalDendrite>(connections.GetSegments(connections.GetCell(1)));
        if (grownSegments.Count != 0)
        {
            sawGrowthOnCell1 = true;
        }
        else
        {
            List<DistalDendrite> cell2Segments = connections.GetSegments(connections.GetCell(2));
            Assert.IsFalse(cell2Segments.Count == 0);
            sawGrowthOnCell2 = true;
            grownSegments.AddRange(cell2Segments);
        }

        Assert.AreEqual(1, grownSegments.Count);
        List<Synapse> grownSynapses = grownSegments[0].GetAllSynapses(connections);
        Assert.AreEqual(4, grownSynapses.Count);

        // Tick off each previously active column as a synapse pointing into it is found.
        HashSet<Column> uncoveredColumns = connections.GetColumnSet(earlierColumns);
        foreach (Synapse grown in grownSynapses)
        {
            Assert.AreEqual(0.2, grown.GetPermanence(), 0.01);

            Column sourceColumn = grown.GetPresynapticCell().GetColumn();
            // Remove reports whether the column was still outstanding.
            Assert.IsTrue(uncoveredColumns.Remove(sourceColumn));
        }
        Assert.AreEqual(0, uncoveredColumns.Count);
    }

    Assert.IsTrue(sawGrowthOnCell1);
    Assert.IsTrue(sawGrowthOnCell2);
}