Introduction to PyTorch Detach
PyTorch detach() returns a new tensor that shares its storage with the original tensor but is detached from the current computational graph. The result never requires a gradient, so no gradient flows back through it to the original tensor. detach() also affects forward-mode automatic differentiation: the result never carries forward-mode gradients either.
PyTorch Detach Overview
- detach() removes a tensor from the computational graph that autograd builds in order to compute gradients. It is commonly used to record the loss and accuracy at the end of an epoch when training a neural network: the detached value can be logged or accumulated without affecting the gradients. If the loss tensors were accumulated without detaching them, autograd would keep every step's graph and intermediate results alive, which requires more memory. Operations performed on the detached tensor are not recorded, so no gradient flows back through them.
- When a tensor has to be removed from the computational graph, detach can be used. PyTorch implements automatic differentiation by recording the operations performed on tensors that require gradients, so that it can compute the gradient of each one. This recorded graph takes memory. detach() returns a view of the tensor that autograd does not track, so the operations that follow on that view are not recorded. To track the tensor again, we can call requires_grad_(True) on the detached tensor, which makes it the leaf of a new graph.
- We can also use detach().numpy() to convert a tensor into a NumPy array. Calling .numpy() directly on a tensor that requires grad raises an error, so the tensor has to be detached from the graph first. For a CPU tensor the resulting array shares memory with the tensor, but operations on the array are invisible to autograd, so gradient tracking is lost completely for that data.
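A minimal sketch of the three uses described above, assuming a toy linear model trained on random data (the names model, inputs, targets, epoch_loss, and restarted are illustrative only, not from the original article): accumulating an epoch's loss without keeping each step's graph alive, converting a tracked tensor to NumPy, and switching tracking back on for a detached tensor.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: a linear model and random data
model = torch.nn.Linear(3, 1)
inputs = torch.randn(8, 3)
targets = torch.randn(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epoch_loss = 0.0
for step in range(4):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    # detach() so the running total does not keep this step's graph alive
    epoch_loss += loss.detach()
print("mean loss for the epoch:", (epoch_loss / 4).item())

# .numpy() refuses a tensor that requires grad, so detach it first
try:
    model.weight.numpy()
except RuntimeError as err:
    print("numpy() without detach failed:", err)
weights_np = model.weight.detach().numpy()  # shares memory with model.weight (CPU)

# A detached tensor is untracked until tracking is switched back on
restarted = model.weight.detach().requires_grad_(True)
print(restarted.requires_grad, restarted.is_leaf)  # True True
```

Accumulating loss.item() instead of loss.detach() would work equally well for logging; detach() is used here only to match the text.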
How does Detach Work?
Let us look at examples with and without detach.
First, without detach: here b equals a^4 and c equals a^6, so their sum equals a^4 + a^6. The derivative is 4a^3 + 6a^5, so at a = 2 the gradient is 4*2^3 + 6*2^5 = 32 + 192 = 224. a.grad is a vector with 20 elements, each with the value 224.
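The code for this example is not shown, so here is a minimal sketch under stated assumptions: a is taken to be a 20-element leaf tensor filled with 2.0, and the sum is reduced with .sum() before calling backward().

```python
import torch

# Assumption: a is a 20-element leaf tensor filled with 2.0
a = torch.full((20,), 2.0, requires_grad=True)
b = a ** 4
c = a ** 6
l = b + c           # both terms are tracked by autograd
l.sum().backward()  # d/da (a^4 + a^6) = 4a^3 + 6a^5

print(a.grad)  # every element is 4*2**3 + 6*2**5 = 224
```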
Next, an example where detach is used.
Here c does not contribute to the gradient, because it is computed from a tensor that has been detached from the graph. Only the tracked term counts, and since the derivative is 3a^2, that term is a^3; at a = 2 this gives 3*2^2 = 12. a.grad is again a vector with 20 elements, all with the value 12.
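The code for this example is also missing. A minimal sketch, assuming the tracked term is b = a^3 and that c is computed from b.detach() (so c has the value a^6 but is invisible to autograd):

```python
import torch

a = torch.full((20,), 2.0, requires_grad=True)
b = a ** 3
c = b.detach() ** 2  # same value as a^6, but autograd does not see it
l = b + c
l.sum().backward()   # only b contributes: d/da a^3 = 3a^2

print(a.grad)  # every element is 3*2**2 = 12
```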
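import torch
m = torch.arange(5., requires_grad=True)
n = m**2            # pow saves m for the backward pass
o = m.detach()      # o shares storage with m
o.zero_()           # in-place change, also visible through m
n.sum().backward()  # raises RuntimeError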
An error is raised here because the data needed for the backward pass has been changed: m**2 saves m to compute its gradient, and o.zero_() overwrites that same storage in place. If we remove the o.zero_() call, backward succeeds and we get the gradient value 2m. detach does not copy the tensor; it returns a tensor that shares its data with the original, so an in-place modification through either one is seen by both. In other words, detach blocks gradients while still sharing the data. detach is useful when a tensor's values are needed but should not be part of the computational graph.
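A small sketch that runs both variants side by side, with and without the in-place o.zero_() call (the helper name run and its flag are illustrative). It catches the error instead of letting it stop the program:

```python
import torch

def run(zero_detached):
    m = torch.arange(5., requires_grad=True)
    n = m ** 2          # pow saves m for the backward pass
    o = m.detach()      # shares storage (and version counter) with m
    if zero_detached:
        o.zero_()       # in-place change is visible through m as well
    try:
        n.sum().backward()
        return m.grad
    except RuntimeError as err:
        return err

print(run(zero_detached=True))   # the error message about an in-place modification
print(run(zero_detached=False))  # tensor([0., 2., 4., 6., 8.])
```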
PyTorch Detach Method
PyTorch keeps track of the operations performed on tensors that require gradients so that it can compute those gradients later. These operations are recorded as a graph, and the detach method creates a new view of a tensor that is not part of that graph and does not need gradients. Operations on the detached tensor are not recorded, so no graph is built for results computed only from it. To see how the gradient is computed for a given tensor, the graph can be visualized with the third-party torchviz package.
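A minimal sketch of inspecting the recorded graph without extra packages, using the grad_fn attribute and its next_functions (the tensor names are illustrative). torchviz's make_dot draws the same graph as an image, but it needs torchviz and Graphviz installed, so it is not used here.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = (x * 2).sum()

# Each tracked result points to the node that created it
print(y.grad_fn)                 # a SumBackward0 node
print(y.grad_fn.next_functions)  # links back to the MulBackward0 node

d = x.detach()
print(d.grad_fn, d.requires_grad, d.is_leaf)  # None False True
z = (d * 2).sum()                # nothing is recorded for this computation
print(z.grad_fn)                 # None
```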
Operations performed on a detached tensor cannot be tracked.
For example, if c is a detached tensor, c**6 is no longer tracked, which is how the detach method works in PyTorch.
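A minimal sketch of this behavior, assuming c is obtained by detaching an intermediate result b (the names a, b, c, d are illustrative). Because d = c**6 has no graph behind it, calling backward() on it fails:

```python
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)
b = a * 3
c = b.detach()
d = c ** 6  # not tracked: c is outside the graph

print(d.requires_grad)  # False

backward_error = None
try:
    d.sum().backward()
except RuntimeError as err:
    # there is no graph behind d, so backward has nothing to follow
    backward_error = err
print("backward failed:", backward_error)
```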
The detached tensor uses the same storage as the original tensor, so in-place modifications made through either tensor are visible in the other. detach also affects forward-mode AD: the result never has forward-mode AD gradients.
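A short sketch of both properties, using torch.autograd.forward_ad for the forward-mode part (the primal and tangent values are arbitrary choices for illustration):

```python
import torch
import torch.autograd.forward_ad as fwAD

# Shared storage: a change through the detached tensor shows up in p
p = torch.ones(3, requires_grad=True)
s = p.detach()
print(p.data_ptr() == s.data_ptr())  # True
s[0] = 5.0
print(p)  # tensor([5., 1., 1.], requires_grad=True)

# Forward-mode AD: detach() drops the tangent
primal = torch.tensor([1.0, 2.0])
tangent = torch.tensor([1.0, 1.0])
with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual * 2
    out_tangent = fwAD.unpack_dual(out).tangent.clone()   # tensor([2., 2.])
    detached_tangent = fwAD.unpack_dual(out.detach()).tangent  # None
print(out_tangent, detached_tangent)
```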
Example of PyTorch Detach
Given below is an example that checks which ways of copying a tensor share storage with the original tensor p:
print("it is the same storage space")
print("it is different storage space")
p = torch.ones((4,5), requires_grad=True)
q = p
r = p.data
s = p.detach()
t = p.data.clone()
u = p.clone()
v = p.detach().clone()
w = torch.empty_like(p).copy_(p)
x = torch.tensor(p)
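If we need to copy-construct a tensor from another tensor, we can use sourceTensor.clone().detach(), or sourceTensor.clone().detach().requires_grad_(True) if the copy should track gradients. torch.tensor(sourceTensor) also copies the data, but PyTorch warns against using it for this purpose and recommends the clone().detach() forms instead.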
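A minimal sketch comparing the recommended copy constructions with torch.tensor(sourceTensor), recording the warning that PyTorch emits for the latter (the variable names are illustrative):

```python
import warnings
import torch

source = torch.ones(2, 2, requires_grad=True)

# An independent copy outside the graph
plain_copy = source.clone().detach()
# An independent copy that starts its own graph as a new leaf
tracked_copy = source.clone().detach().requires_grad_(True)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_copy = torch.tensor(source)  # works, but PyTorch warns about it

print(plain_copy.requires_grad, tracked_copy.requires_grad, tracked_copy.is_leaf)  # False True True
for w in caught:
    print(w.category.__name__, w.message)
```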
The output shows which variables share storage with p: q, r, and s share it, while t, u, v, w, and x each get new storage. PyTorch offers many ways to copy a tensor, so it helps to choose one whose interaction with autograd is explicit. For example, copy_() is recorded by autograd when the source requires gradients, so w is still part of the graph. For a plain copy that sits outside the graph, it is better to combine clone and detach like this:
b = a.clone().detach()
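The timing of the different copying methods can be compared with a benchmark like the following (it assumes the third-party perfplot package is installed):
import torch
import perfplot  # third-party benchmarking package

perfplot.show(
    setup=lambda l: torch.randn(l),
    kernels=[
        lambda x: x.new_tensor(x),
        lambda x: x.clone().detach(),
        lambda x: torch.empty_like(x).copy_(x),
        lambda x: torch.tensor(x),
        lambda x: x.detach().clone(),
    ],
    labels=["new_tensor()", "clone().detach()", "empty_like().copy()", "tensor()", "detach().clone()"],
    n_range=[3 ** i for i in range(15)],
    xlabel="number of elements",
    title='Comparison for timing related to PyTorch tensor copies',
)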
We cannot use the clone method alone when we want an independent copy, because clone is recorded in the graph: gradients computed through the cloned tensor flow back into the original tensor's .grad. This can lead to errors that are hard to track down. Adding detach() disconnects the copy from the graph, so it can no longer affect the gradients of the original tensor.
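A minimal sketch of the difference, with illustrative names: gradients computed through p.clone() land in p.grad, while a copy made with detach() and clone() (in either order) has no graph and shares no storage with p.

```python
import torch

p = torch.ones(3, requires_grad=True)

# clone() alone stays in the graph: gradients flow back into p
cloned = p.clone()
(cloned * 3).sum().backward()
grad_through_clone = p.grad.clone()
print(grad_through_clone)  # tensor([3., 3., 3.])

# detach() cuts the copy out of the graph; both orders give the same values
p.grad = None
independent = p.detach().clone()
also_independent = p.clone().detach()
print(independent.requires_grad, independent.grad_fn)  # False None
print(torch.equal(independent, also_independent))      # True
```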
To copy a tensor and take the copy out of the computational graph, clone should be combined with detach. Either order gives the same values: detach().clone() avoids recording the clone operation, while clone().detach() records it and then discards it. The code for detaching is short, but we should be clear about which tensors remain part of the computational graph and which do not.
This is a guide to PyTorch Detach. Here we discussed the introduction, an overview, how detach works, the detach method, and an example.