コード例 #1
0
                /// <summary>
                /// Relocates this module, its child modules, and any tensor or parameter fields
                /// declared on its runtime type to the requested device.
                /// </summary>
                /// <param name="deviceType">The device type, e.g. 'CPU' or 'CUDA'.</param>
                /// <param name="deviceIndex">
                /// The optional device index. Only honored for CUDA; for any other device type it is reset to -1.
                /// </param>
                /// <returns>This module instance, so calls can be chained.</returns>
                /// <remarks>
                /// Nothing is moved when the module's recorded device type and index already match the target.
                /// </remarks>
                public virtual Module to(DeviceType deviceType, int deviceIndex = -1)
                {
                    // Only CUDA devices carry an index; normalize every other device type to -1.
                    deviceIndex = deviceType == DeviceType.CUDA ? deviceIndex : -1;

                    bool alreadyOnTarget = deviceType == _deviceType && deviceIndex == _deviceIndex;
                    if (!alreadyOnTarget)
                    {
                        // Move the native module first, then recurse into the children.
                        torch.InitializeDeviceType(deviceType);
                        THSNN_Module_to_device(handle, (int)deviceType, deviceIndex);
                        torch.CheckForErrors();

                        foreach (var (_, child) in named_children())
                        {
                            child.to(deviceType, deviceIndex);
                        }

                        // Tensor-valued fields on the runtime type are moved here and written back via reflection.
                        const BindingFlags instanceFields = BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance;
                        foreach (FieldInfo fieldInfo in GetType().GetFields(instanceFields))
                        {
                            switch (fieldInfo.GetValue(this))
                            {
                                // This case must precede the Tensor case.
                                case Modules.Parameter oldParam:
                                    if (deviceType != oldParam.device_type || deviceIndex != oldParam.device_index)
                                    {
                                        var movedParam = new Modules.Parameter(oldParam.to(deviceType, deviceIndex), oldParam.requires_grad);
                                        fieldInfo.SetValue(this, movedParam);
                                        ConditionallyRegisterParameter(fieldInfo.Name, movedParam);
                                    }
                                    break;

                                case Tensor oldTensor:
                                    if (deviceType != oldTensor.device_type || deviceIndex != oldTensor.device_index)
                                    {
                                        var movedTensor = oldTensor.to(deviceType, deviceIndex);
                                        fieldInfo.SetValue(this, movedTensor);
                                        ConditionallyRegisterBuffer(fieldInfo.Name, movedTensor);
                                    }
                                    break;
                            }
                        }

                        _deviceType  = deviceType;
                        _deviceIndex = deviceIndex;
                    }

                    Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
                    return this;
                }
コード例 #2
0
ファイル: Module.cs プロジェクト: xamarin/TorchSharp
                /// <summary>
                /// Converts the parameters and buffers of this module, and recursively of its
                /// child modules, to the given element type.
                /// </summary>
                /// <param name="dtype">The target element type.</param>
                /// <returns>This module instance, so calls can be chained.</returns>
                public virtual Module to(ScalarType dtype)
                {
                    THSNN_Module_to_dtype(handle, (sbyte)dtype);
                    torch.CheckForErrors();

                    foreach (var(_, sm) in named_children())
                    {
                        sm.to(dtype);
                    }

                    // Tensor-valued fields declared on the runtime type are converted here
                    // and written back through reflection.
                    foreach (var field in this.GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance))
                    {
                        var name  = field.Name;
                        var value = field.GetValue(this);

                        Tensor            tensor = value as Tensor;
                        Modules.Parameter param  = value as Modules.Parameter;

                        if (param is not null)    // This test must come before the Tensor test
                        {
                            if (dtype != param.dtype)
                            {
                                var t = param.to(dtype);
                                // libtorch rejects retain_grad() on a tensor with requires_grad == false
                                // (e.g. a frozen parameter, or a conversion to a non-floating dtype),
                                // so only request it when autograd is actually tracking the result.
                                if (t.requires_grad)
                                {
                                    t.retain_grad();
                                }
                                var p = new Modules.Parameter(t, param.requires_grad);
                                field.SetValue(this, p);
                                ConditionallyRegisterParameter(name, p);
                            }
                        }
                        else if (tensor is not null)
                        {
                            if (dtype != tensor.dtype)
                            {
                                var t = tensor.to(dtype);
                                field.SetValue(this, t);
                                ConditionallyRegisterBuffer(name, t);
                            }
                        }
                    }

                    return this;
                }