/// <summary>
/// Moves the parameters and buffers.
/// </summary>
/// <param name="deviceType">The device type, e.g. 'CPU' or 'CUDA'.</param>
/// <param name="deviceIndex">The optional device index. Forced to -1 unless <paramref name="deviceType"/> is CUDA.</param>
/// <returns>This module, so calls can be chained.</returns>
/// <remarks>
/// Does nothing if the module's recorded device already matches. Otherwise it moves the native module,
/// recurses into the children returned by <c>named_children()</c>, and then uses reflection over this
/// type's instance fields. Any <see cref="Modules.Parameter"/> or <see cref="Tensor"/> field that is not
/// yet on the target device is replaced with a converted copy.
/// </remarks>
public virtual Module to(DeviceType deviceType, int deviceIndex = -1)
{
    // Only CUDA devices carry an index here; everything else is normalized to -1.
    if (deviceType != DeviceType.CUDA) {
        deviceIndex = -1;
    }

    if (deviceType != _deviceType || deviceIndex != _deviceIndex) {
        torch.InitializeDeviceType(deviceType);
        THSNN_Module_to_device(handle, (int)deviceType, deviceIndex);
        torch.CheckForErrors();

        foreach (var (_, sm) in named_children()) {
            sm.to(deviceType, deviceIndex);
        }

        foreach (var field in this.GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
            var name = field.Name;
            var value = field.GetValue(this);

            Tensor tensor = value as Tensor;
            Modules.Parameter param = value as Modules.Parameter;

            if (param is not null) // This test must come before the Tensor test (a Parameter is also a Tensor)
            {
                if (deviceType != param.device_type || deviceIndex != param.device_index) {
                    // Convert a detached view so the replacement Parameter is a leaf tensor.
                    // A differentiable to() of a requires-grad parameter would produce a non-leaf
                    // whose .grad is never populated by backward (PyTorch's Module._apply converts
                    // under no_grad for the same reason). The device differs, so to() copies and
                    // the temporary detached wrapper can be released immediately.
                    using var detached = param.detach();
                    var p = new Modules.Parameter(detached.to(deviceType, deviceIndex), param.requires_grad);
                    field.SetValue(this, p);
                    ConditionallyRegisterParameter(name, p);
                }
            } else if (tensor is not null) {
                if (deviceType != tensor.device_type || deviceIndex != tensor.device_index) {
                    var t = tensor.to(deviceType, deviceIndex);
                    field.SetValue(this, t);
                    ConditionallyRegisterBuffer(name, t);
                }
            }
        }

        _deviceType = deviceType;
        _deviceIndex = deviceIndex;
    }

    Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
    return this;
}
/// <summary>
/// Convert the parameters and buffers.
/// </summary>
/// <param name="dtype">The element type to convert parameters and buffers to.</param>
/// <returns>This module, so calls can be chained.</returns>
/// <remarks>
/// Converts the native module, recurses into the children returned by <c>named_children()</c>, and then
/// uses reflection over this type's instance fields. Any <see cref="Modules.Parameter"/> or
/// <see cref="Tensor"/> field whose element type differs from <paramref name="dtype"/> is replaced
/// with a converted copy.
/// </remarks>
public virtual Module to(ScalarType dtype)
{
    THSNN_Module_to_dtype(handle, (sbyte)dtype);
    torch.CheckForErrors();

    foreach (var (_, sm) in named_children()) {
        sm.to(dtype);
    }

    foreach (var field in this.GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
        var name = field.Name;
        var value = field.GetValue(this);

        Tensor tensor = value as Tensor;
        Modules.Parameter param = value as Modules.Parameter;

        if (param is not null) // This test must come before the Tensor test (a Parameter is also a Tensor)
        {
            if (dtype != param.dtype) {
                // Convert a detached view so the replacement Parameter is a leaf tensor that
                // accumulates its own .grad. This removes the need for retain_grad(), which
                // libtorch rejects for tensors that do not require grad (frozen parameters).
                // The dtype differs, so to() copies and the detached wrapper can be released.
                using var detached = param.detach();
                var p = new Modules.Parameter(detached.to(dtype), param.requires_grad);
                field.SetValue(this, p);
                ConditionallyRegisterParameter(name, p);
            }
        } else if (tensor is not null) {
            if (dtype != tensor.dtype) {
                var t = tensor.to(dtype);
                field.SetValue(this, t);
                ConditionallyRegisterBuffer(name, t);
            }
        }
    }

    return this;
}