Exemplo n.º 1
0
 /// <summary>
 /// Creates an engine that uses <c>Backend_CPU</c> and starts with one scope,
 /// named "default_scope", already pushed onto the scope stack and active.
 /// </summary>
 public Engine()
 {
     this.backend = new Backend_CPU();

     // Root scope: it is the active scope until the first startScope call.
     var rootScope = new ScopeState();
     rootScope.name  = "default_scope";
     rootScope.track = new List<Tensor>();

     this.activeScope = rootScope;
     this.scopeStack.Push(rootScope);
 }
Exemplo n.º 2
0
        /// <summary>
        /// Closes the innermost scope and decides what happens to every tensor it tracked.
        /// </summary>
        /// <remarks>
        /// A tracked tensor is skipped when its id is in <c>keepTensors</c> or it is one of the
        /// <paramref name="result"/> tensors. Any other tracked tensor is disposed, unless a
        /// gradient tape is still active. In that case it is forwarded to the parent scope instead
        /// (presumably so tape entries keep referring to live tensors).
        /// After the scope is popped, every result/forwarded tensor that was tracked by the closed
        /// scope and is not in <c>keepTensors</c> is passed to <c>track</c>.
        /// </remarks>
        /// <param name="result">
        /// Tensors produced inside the scope that must survive it. <c>null</c> is treated as an
        /// empty list.
        /// </param>
        /// <param name="gradientsMode">
        /// Pass <c>true</c> when closing a scope opened with <c>gradientsMode</c>. The gradient
        /// scope counter is decremented, and the active tape is cleared when it reaches zero.
        /// </param>
        public void endScope(List<Tensor> result, bool gradientsMode = false)
        {
            if (gradientsMode)
            {
                this.gradientScopeCount--;
                if (this.gradientScopeCount == 0)
                {
                    // Leaving the outermost gradient scope: with no tape, the loop below
                    // disposes unkept tensors instead of forwarding them.
                    this.activeTape = null;
                }
            }

            // A null result used to throw NullReferenceException; treat it as "nothing returned".
            var tensorsToTrackInParent = result != null ? new List<Tensor>(result) : new List<Tensor>();

            // HashSet gives O(1) membership tests; List.Contains inside the loop was O(n·m).
            var tensorsToKeep = new HashSet<int>(this.keepTensors);
            tensorsToKeep.UnionWith(tensorsToTrackInParent.Select(p => p.id));

            // Index-based on purpose: Count is re-read each iteration, so this tolerates
            // dispose() touching the track list, where a foreach would throw.
            for (var i = 0; i < this.activeScope.track.Count; i++)
            {
                var tensor = this.activeScope.track[i];
                if (tensorsToKeep.Contains(tensor.id))
                {
                    continue;
                }

                if (this.activeTape != null)
                {
                    tensorsToTrackInParent.Add(tensor);
                }
                else
                {
                    tensor.dispose();
                }
            }

            var oldScope = this.scopeStack.Pop();

            // Fall back to a fresh scope named like the constructor's root scope.
            this.activeScope = this.scopeStack.Count == 0
                ? new ScopeState { name = "default_scope", track = new List<Tensor>() }
                : this.scopeStack.Peek();

            // Only re-track tensors that were allocated in the scope just closed.
            var oldScopeIds = new HashSet<int>(oldScope.track.Select(p => p.id));
            foreach (var tensor in tensorsToTrackInParent)
            {
                if (!this.keepTensors.Contains(tensor.id) && oldScopeIds.Contains(tensor.id))
                {
                    this.track(tensor);
                }
            }
        }
Exemplo n.º 3
0
        /// <summary>
        /// Opens a new, empty scope, pushes it onto the scope stack and makes it active.
        /// </summary>
        /// <param name="name">Optional scope name; left unset when <c>null</c>.</param>
        /// <param name="gradientsMode">
        /// When <c>true</c>, the gradient scope counter is incremented. The outermost gradient
        /// scope (counter was zero) also starts a new, empty tape.
        /// </param>
        public void startScope(string name = null, bool gradientsMode = false)
        {
            if (gradientsMode)
            {
                // Only the outermost gradient scope creates a tape; nested ones reuse it.
                if (this.gradientScopeCount == 0)
                {
                    this.activeTape = new List<TapeNode>();
                }
                this.gradientScopeCount++;
            }

            var scope = new ScopeState { track = new List<Tensor>() };

            // Leave the name unset when none was supplied.
            if (name != null)
            {
                scope.name = name;
            }

            this.scopeStack.Push(scope);
            this.activeScope = scope;
        }