/// <summary>
/// Creates an engine backed by <c>Backend_CPU</c> and opens the root scope,
/// named "default_scope", which becomes both the active scope and the bottom
/// entry of the scope stack.
/// </summary>
public Engine()
{
    this.backend = new Backend_CPU();

    var rootScope = new ScopeState();
    rootScope.name = "default_scope";
    rootScope.track = new List<Tensor>();

    this.activeScope = rootScope;
    this.scopeStack.Push(rootScope);
}
/// <summary>
/// Closes the innermost scope opened by <c>startScope</c>.
/// Every tensor tracked by the closing scope whose id is neither in
/// <c>keepTensors</c> nor among <paramref name="result"/> is disposed. While a
/// gradient tape is active, such tensors are instead carried to the parent scope
/// rather than disposed (presumably so recorded tape nodes can still reference
/// them — NOTE(review): confirm against the tape code).
/// Tensors in <paramref name="result"/> (and any carried tensors) are re-tracked
/// in the new active scope if they were tracked by the closed scope and are not
/// globally kept.
/// </summary>
/// <param name="result">
/// Tensors produced inside the scope that must outlive it. A null value is
/// treated as an empty list.
/// </param>
/// <param name="gradientsMode">
/// Pass true when the matching <c>startScope</c> call used gradients mode; the
/// gradient scope counter is decremented and the tape is dropped when it hits zero.
/// </param>
public void endScope(List<Tensor> result, bool gradientsMode = false)
{
    if (gradientsMode)
    {
        this.gradientScopeCount--;
        if (this.gradientScopeCount == 0)
        {
            this.activeTape = null;
        }
    }

    // Copy so that appending carried tensors below never mutates the caller's list.
    var tensorsToTrackInParent = result != null ? new List<Tensor>(result) : new List<Tensor>();

    // HashSet gives O(1) membership tests; a List here made the loop below quadratic.
    var tensorsToKeep = new HashSet<int>(this.keepTensors);
    foreach (var kept in tensorsToTrackInParent)
    {
        tensorsToKeep.Add(kept.id);
    }

    for (var i = 0; i < this.activeScope.track.Count; i++)
    {
        var tensor = this.activeScope.track[i];
        if (tensorsToKeep.Contains(tensor.id))
        {
            continue;
        }

        if (this.activeTape != null)
        {
            // A tape is recording: hand the tensor to the parent scope instead of disposing it.
            tensorsToTrackInParent.Add(tensor);
        }
        else
        {
            tensor.dispose();
        }
    }

    var oldScope = this.scopeStack.Pop();
    // When the last scope is popped, fall back to a fresh scope named like the constructor's root scope.
    this.activeScope = this.scopeStack.Count == 0
        ? new ScopeState() { name = "default_scope", track = new List<Tensor>() }
        : this.scopeStack.Peek();

    // Only tensors that were tracked by the closed scope are moved into the parent.
    var oldScopeIds = new HashSet<int>(oldScope.track.Select(p => p.id));
    foreach (var tensor in tensorsToTrackInParent)
    {
        if (!this.keepTensors.Contains(tensor.id) && oldScopeIds.Contains(tensor.id))
        {
            this.track(tensor);
        }
    }
}
/// <summary>
/// Opens a new scope, pushes it onto the scope stack and makes it the active scope.
/// </summary>
/// <param name="name">Optional scope name; when null the scope's name is left unset.</param>
/// <param name="gradientsMode">
/// When true, the gradient scope counter is incremented, and a new empty tape is
/// created if this is the outermost gradient scope (counter was zero).
/// </param>
public void startScope(string name = null, bool gradientsMode = false)
{
    if (gradientsMode)
    {
        // Only the outermost gradient scope starts a fresh tape.
        if (this.gradientScopeCount == 0)
        {
            this.activeTape = new List<TapeNode>();
        }
        this.gradientScopeCount++;
    }

    var freshScope = new ScopeState
    {
        track = new List<Tensor>()
    };
    if (name != null)
    {
        freshScope.name = name;
    }

    this.scopeStack.Push(freshScope);
    this.activeScope = freshScope;
}