
PyTorch: detach() and detach_()

In short:
commonly used pattern: var.detach().float().cpu()
detach() doesn't affect the original graph; it returns a new tensor that shares the same data but has requires_grad=False. [There is more to it, see below.]
detach_() is the in-place version of detach().
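
For example, a minimal sketch of that pattern (the tensor name var is just illustrative, not from the example further down):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
var = torch.randn(3, requires_grad=True, device=device)

out = var.detach().float().cpu()   # detached from the graph, float32, on the CPU
print(out.requires_grad)           # False -> safe to call .numpy()
print(var.requires_grad)           # True  -> the original tensor is untouched
print(out.numpy())                 # works; var.numpy() would raise an error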


More on detach():
.detach() does not change the requires_grad property of the tensor it is called on; it returns a new tensor with requires_grad=False, so the operations that follow it are excluded from the graph.

Think of detach() as the breaking point between two graphs.
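
A minimal sketch of that "breaking point" idea (variable names are my own):

import torch

x = torch.tensor([2.0], requires_grad=True)
y = x * 3            # part of the graph
z = y.detach() * 5   # y.detach() cuts the link back to x

out = y.sum() + z.sum()
out.backward()
print(x.grad)        # tensor([3.]) -- only the y path contributes;
                     # without detach(), x.grad would be tensor([18.])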
https://github.com/pytorch/examples/issues/116
http://www.bnikolic.co.uk/blog/pytorch-detach.html

https://github.com/szagoruyko/pytorchviz/blob/master/examples.ipynb
http://ruotianluo.github.io/2017/02/11/pytorch-attempt/

https://blog.csdn.net/u012436149/article/details/76714349

Detach
This method is described in the official documentation.

It returns a new Variable, detached from the current graph.
The returned Variable will never require gradients.
If the Variable being detached has volatile=True, the detached Variable is volatile as well. (volatile comes from old PyTorch versions and has since been removed.)
One caveat: the returned Variable and the original Variable point to the same underlying tensor, so they share storage.
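
That caveat matters in practice: an in-place change to the detached tensor also changes the data the original sees, and backward() will then complain if it needs that data. A minimal sketch (names are illustrative):

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.sigmoid(x)     # sigmoid's backward needs y's values
z = y.detach()

z.zero_()                # in-place change; y's data changes too (shared storage)
print(y)                 # tensor([0., 0.], grad_fn=<SigmoidBackward...>)

# y.sum().backward()     # would raise: a variable needed for gradient
                         # computation has been modified by an in-place operation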


detach_() modifies the variable it is called on in place (it becomes detached from the graph), whereas detach() leaves the original untouched and returns a new tensor.
detach_() is the in-place version of detach().
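
A minimal sketch of the difference (names are illustrative):

import torch

x = torch.tensor([1.0], requires_grad=True)
y = x * 2

d = y.detach()                               # new tensor; y itself still has a grad_fn
print(y.grad_fn is None, d.requires_grad)    # False False

y.detach_()                                  # in-place: y itself is detached from the graph
print(y.grad_fn is None, y.requires_grad)    # True False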


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import numpy as np

seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)


def get_mse_(pred, truth):  # images assumed to be in range 0-1
    diff = pred.detach().cpu() - truth.detach().cpu()
    mse = np.mean(np.square(diff.numpy()), axis=-1)  # per-image MSE: shape (batch_size,)
    return np.mean(mse)

def get_mse(diff):  # expects a difference tensor that is already detached and on the CPU
    mse = np.mean(np.square(diff.numpy()), axis=-1)  # per-image MSE: shape (batch_size,)
    return np.mean(mse)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # use the GPU when one is available

class LinReg(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.beta = nn.Linear(input_dim, 1)
        
    def forward(self, X):
        return self.beta(X)

model = LinReg(1).to(device)
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size(), "\t", model.state_dict()[param_tensor])

loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

X = torch.tensor([[1],[2],[3],[4],[5]]).float().to(device)
y = torch.tensor([[1],[4],[6],[8],[10]]).float().to(device)
y_ = model(X)

print(y_, y_.requires_grad)
print(y, y.requires_grad)

diff = y_ - y
print(diff, diff.requires_grad)


diff2 = y_.detach() - y.detach()  # detach_() would instead detach y_ and y in place
print(diff2, diff2.requires_grad)

print(get_mse(diff))  # raises an error: diff still requires grad, so .numpy() refuses to convert it

print(get_mse(diff2))  # raises an error on a GPU: a CUDA tensor must be moved to the CPU before .numpy()

print(get_mse(diff2.cpu()))  # works: detached and on the CPU

print('Weight: {}\t Bias: {}\t MSE: {}'.format(model.beta.weight.data,model.beta.bias.data, get_mse_(y_,y)))

print('Weight: {}\t Bias: {}'.format(model.beta.weight.data,model.beta.bias.data))
for _ in range(20):
    model.train()
    optimizer.zero_grad()
    y_ = model(X)
    loss = loss_fn(y_, y)
    loss.backward()
    optimizer.step()
    print('Weight: {}\t Bias: {}\t MSE: {}'.format(model.beta.weight.data,model.beta.bias.data, get_mse_(y_,y)))