Introduction to PyTorch Flatten
PyTorch Flatten is used to reshape a tensor with any number of dimensions into a single dimension, so that further operations can be performed on the same input data. The flattened tensor has one dimension whose size equals the number of elements in the original tensor. The main purpose is to remove all dimensions but one. Inside a network, it is mostly used to reshape intermediate tensors before they are passed to the next layer.
What is PyTorch Flatten?
A hand-written flatten function takes a tensor t as its argument and reshapes it with t.reshape(1, -1), passing -1 as the second argument, and then removes the leading dimension of size 1. The -1 tells PyTorch to infer that size from the number of elements in the tensor, which is the product of the tensor's dimension sizes (not of its values). For example, a 2 × 3 tensor has six elements, so its flattened shape is (6,). Flattening is also needed when values are passed from a convolutional layer to a linear layer. If we only want to flatten some of the dimensions, we can pass the start_dim and end_dim parameters.
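The reshape-based flatten described above can be sketched as a small function. This is an illustration, not PyTorch's own implementation, and the name my_flatten is hypothetical. It calls t.reshape(1, -1), so PyTorch infers the second size from the element count, and then squeeze(0) drops the leading dimension of size 1.

```python
import math

import torch


def my_flatten(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten t to one dimension with reshape."""
    # -1 lets PyTorch infer this size: the total number of elements
    t = t.reshape(1, -1)
    # Remove only the leading size-1 dimension, leaving a 1-D tensor
    return t.squeeze(0)


t = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)
flat = my_flatten(t)

# The element count is the product of the dimension sizes: 2 * 3 = 6
print(t.numel(), math.prod(t.shape))        # 6 6
print(flat)                                 # tensor([1, 2, 3, 4, 5, 6])
print(torch.equal(flat, torch.flatten(t)))  # True
```

squeeze(0) is used instead of a bare squeeze() so that a tensor with a single element still comes out one-dimensional rather than zero-dimensional.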
PyTorch Flatten Function
Depending on the input, torch.flatten may return the original tensor, a view of it, or a copy of its data. If no dimensions are flattened (start_dim equals end_dim), the original tensor is returned. Otherwise, if the input can be viewed in the flattened shape, that view is returned; only if it cannot be viewed that way (for example, a non-contiguous tensor such as a transpose) is the data copied. Flattening a zero-dimensional tensor returns a one-dimensional tensor with a single element. In other words, unlike NumPy's flatten, which always copies, torch.flatten behaves like reshape: it avoids copying whenever a view is possible.
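The view-versus-copy behaviour can be checked directly. This sketch compares data_ptr() values, which is one informal way to see whether two tensors share memory; the tensor values are arbitrary.

```python
import torch

x = torch.arange(6).reshape(2, 3)  # contiguous 2 x 3 tensor

# Contiguous input: flatten returns a view that shares x's memory
v = torch.flatten(x)
print(v.data_ptr() == x.data_ptr())  # True
v[0] = 100                           # writing through the view...
print(x[0, 0].item())                # 100 (...changes x as well)

# Non-contiguous input (a transpose): the data must be copied
xt = torch.arange(6).reshape(2, 3).t()
c = torch.flatten(xt)
print(c.data_ptr() == xt.data_ptr())  # False
print(c)                              # tensor([0, 3, 1, 4, 2, 5])

# A zero-dimensional tensor becomes a one-dimensional tensor with one element
s = torch.flatten(torch.tensor(7))
print(s.shape)  # torch.Size([1])
```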
There are three ways to flatten tensors in PyTorch. The first is the object-oriented (method) style, where Tensor.flatten is called directly on the tensor: x.flatten(). The second is the functional style, torch.flatten, so the code to flatten x is torch.flatten(x). The third is the module style, nn.Flatten(), a subclass of nn.Module that is mostly used in the definition of a model, for example inside nn.Sequential. Note that nn.Flatten defaults to start_dim=1, so it keeps the batch dimension, while torch.flatten and Tensor.flatten default to start_dim=0.
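The three styles can be compared on one tensor; the comparison also shows the different default start_dim of the module version. The input shape is arbitrary.

```python
import torch
from torch import nn

x = torch.randn(2, 3, 4)  # e.g. a batch of 2 samples, each 3 x 4

a = x.flatten()        # method style (Tensor.flatten)
b = torch.flatten(x)   # functional style
m = nn.Flatten()       # module style, default start_dim=1
c = m(x)

print(a.shape)  # torch.Size([24])
print(b.shape)  # torch.Size([24])
print(c.shape)  # torch.Size([2, 12]): the batch dimension is kept

# Passing start_dim=0 makes the module flatten everything as well
print(nn.Flatten(start_dim=0)(x).shape)  # torch.Size([24])
```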
PyTorch Flatten Examples
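The following excerpts call torch.flatten inside model methods and helper functions. They are shown outside their classes, so names such as self.model, self.full_conn1, AbstractWorker, TensorDataPB, TORCH_DTYPE_STR, MEAN and STANDARD_DEVIATION must be defined elsewhere, and the snippets do not run on their own.

```python
import torch
import torch.nn.functional as F


def flatten_ex(self, a):
    a = self.model(a)
    # Flatten every dimension except the batch dimension (dim 0)
    a = torch.flatten(a, start_dim=1)
    a = self.full_conn1(a)
    a = self.norm1(a)
    a = F.relu(a)
    # training=self.training disables dropout in eval mode
    a = F.dropout(a, training=self.training)
    a = self.full_conn2(a)
    a = F.relu(a)
    a = F.dropout(a, training=self.training)
    a = self.full_conn3(a)
    return a


def tensor_serializer(worker: AbstractWorker, tensor: torch.Tensor) -> TensorDataPB:
    """Strategy to serialize a tensor into a TensorDataPB message."""
    datatype = TORCH_DTYPE_STR[tensor.dtype]
    serialize_tensor = TensorDataPB()
    if tensor.is_quantized:
        # Quantized tensors also store their quantization parameters
        serialize_tensor.is_quantized = True
        serialize_tensor.scale = tensor.q_scale()
        serialize_tensor.zero_point = tensor.q_zero_point()
        # int_repr() returns the underlying integer values
        data = torch.flatten(tensor).int_repr().tolist()
    else:
        data = torch.flatten(tensor).tolist()
    serialize_tensor.dtype = datatype
    # The shape is stored separately so the flat data can be reshaped later
    serialize_tensor.shape.dims.extend(tensor.size())
    getattr(serialize_tensor, "contents_" + datatype).extend(data)
    return serialize_tensor


def forward_tensor(self, p):
    p = p.reshape(140, 140, 2)                      # (H, W, C) image with 2 channels
    p = torch.narrow(p, dim=2, start=1, length=1)   # keep one channel: (140, 140, 1)
    p = p.reshape(1, 1, 140, 140)                   # NCHW batch of one image
    p = F.avg_pool2d(p, 5, stride=5)                # downsample to (1, 1, 28, 28)
    p = p / 255
    p = (p - MEAN) / STANDARD_DEVIATION
    p = self.conv1(p)
    p = F.relu(p)
    p = self.conv2(p)
    p = F.max_pool2d(p, 2)
    p = self.dropout1(p)
    # Flatten (N, C, H, W) to (N, C*H*W) for the fully connected layer
    p = torch.flatten(p, 1)
    p = self.fc1(p)
    p = F.relu(p)
    p = self.dropout2(p)
    p = self.fc2(p)
    # The output is 2-D (N, classes), so softmax runs along dim 1
    result = F.softmax(p, dim=1)
    return result


def fun_n0(self, a, b, N, dtype):
    # Grid of coordinates for an a x b output, spaced by the stride
    n0x, n0y = torch.meshgrid(
        torch.arange(1, a * self.stride + 1, self.stride),
        torch.arange(1, b * self.stride + 1, self.stride),
        indexing="ij")
    # Flatten each grid, reshape to (1, 1, a, b), and repeat N times along dim 1
    n0x = torch.flatten(n0x).view(1, 1, a, b).repeat(1, N, 1, 1)
    n0y = torch.flatten(n0y).view(1, 1, a, b).repeat(1, N, 1, 1)
    n0 = torch.cat([n0x, n0y], 1).type(dtype)
    return n0
```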
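Because those excerpts depend on code that is not shown, here is a self-contained sketch of the same pattern (convolution, then torch.flatten(x, 1), then a linear layer). The class name TinyNet and all layer sizes are hypothetical choices for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical example model: conv -> flatten -> linear."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # 1 input channel, 8 output channels; padding=1 keeps 28 x 28
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        # After 2 x 2 max pooling the feature map is 8 x 14 x 14 = 1568 values
        self.fc = nn.Linear(8 * 14 * 14, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.conv(x))  # (N, 8, 28, 28)
        x = F.max_pool2d(x, 2)    # (N, 8, 14, 14)
        x = torch.flatten(x, 1)   # (N, 1568): keep the batch dimension
        return self.fc(x)         # (N, num_classes)


model = TinyNet()
out = model(torch.randn(4, 1, 28, 28))
print(out.shape)  # torch.Size([4, 10])
```

Flattening from dim 1 matters here: torch.flatten(x) with the default start_dim=0 would merge the batch into one long vector, and the linear layer would reject it for any batch size above one.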
PyTorch Flatten parameters
torch.flatten takes three parameters: input, start_dim, and end_dim.
input (Tensor) – the input tensor to flatten.
start_dim (int, optional) – the first dimension to flatten. Default: 0.
end_dim (int, optional) – the last dimension to flatten. Default: -1 (the last dimension).
Let us see an example that uses these parameters.
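```python
>>> import torch
>>> from torch import nn
>>> s = torch.tensor([[[29, 28], [27, 26]], [[25, 24], [23, 22]]])
>>> torch.flatten(s)
tensor([29, 28, 27, 26, 25, 24, 23, 22])
>>> torch.flatten(s, start_dim=1)
tensor([[29, 28, 27, 26],
        [25, 24, 23, 22]])
>>> h = torch.randn(16, 2, 15, 54)
>>> f = nn.Sequential(
...     nn.Conv2d(in_channels=2, out_channels=16, kernel_size=3, stride=1, padding=1),
...     nn.Flatten(),  # default start_dim=1 keeps the batch dimension
... )
>>> result = f(h)
>>> result.size()
torch.Size([16, 12960])
```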
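The example above only sets start_dim. A short sketch that also sets end_dim, including negative indices (counted from the last dimension), on an arbitrary 4-D shape:

```python
import torch
from torch import nn

x = torch.randn(2, 3, 4, 5)

# Merge dims 1 and 2 only: (2, 3, 4, 5) -> (2, 12, 5)
print(torch.flatten(x, start_dim=1, end_dim=2).shape)

# Negative end_dim: merge everything except the last dimension -> (24, 5)
print(torch.flatten(x, start_dim=0, end_dim=-2).shape)

# Negative start_dim: merge only the last two dimensions -> (2, 3, 20)
print(torch.flatten(x, start_dim=-2).shape)

# nn.Flatten accepts the same two arguments -> (24, 5)
print(nn.Flatten(0, 2)(x).shape)
```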
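A related but separate method is flatten_parameters() on recurrent modules (nn.RNN, nn.LSTM, nn.GRU). On the GPU, cuDNN expects an RNN's weights to sit in a single contiguous chunk of memory. When the module is replicated (for example with nn.DataParallel), the weights can end up non-contiguous; PyTorch then warns that they need to be compacted at every call, which can greatly increase memory usage, and calling flatten_parameters() compacts them again. This flattens the parameters in memory; it does not change the shape of any tensor.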
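A minimal sketch of calling flatten_parameters() on an LSTM. The sizes are arbitrary; on a CPU the call returns without doing anything, since the compaction only applies to CUDA weights that cuDNN can use.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"

lstm = nn.LSTM(input_size=10, hidden_size=20).to(device)
# Re-pack the weights into one contiguous chunk (no-op on CPU)
lstm.flatten_parameters()

x = torch.randn(5, 3, 10, device=device)  # (seq_len, batch, input_size)
out, (h, c) = lstm(x)
print(out.shape)  # torch.Size([5, 3, 20])
print(h.shape)    # torch.Size([1, 3, 20])
```

In replicated models the call is often placed at the start of the module's forward method, so each replica compacts its own weights before use.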
PyTorch Flatten Transcript
The first step is to import PyTorch (the torch package).
The installed PyTorch version can be checked through torch.__version__.
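The import and the version check from these two steps take two lines:

```python
import torch

# Prints the installed version string; the exact value depends on your installation
print(torch.__version__)
```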
We can create a sample PyTorch tensor to look at its structure.
```python
import torch

pt_example_tensor1 = torch.Tensor([
    [[3, 1, 4, 5],
     [2, 8, 6, 7]],
    [[13, 10, 9, 15],
     [12, 11, 16, 14]],
    [[23, 17, 18, 20],
     [19, 21, 24, 23]],
])
```
Each matrix has two rows and four columns. Printing the tensor's shape shows its dimensions.
The result is three matrices of two rows and four columns each (shape 3 × 2 × 4), so the flattened result will contain 3 × 2 × 4 = 24 values.
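The shape and the element count can be printed directly. This sketch recreates the example tensor so that it runs on its own:

```python
import torch

# The same example tensor: three 2 x 4 matrices
pt_example_tensor1 = torch.Tensor([
    [[3, 1, 4, 5], [2, 8, 6, 7]],
    [[13, 10, 9, 15], [12, 11, 16, 14]],
    [[23, 17, 18, 20], [19, 21, 24, 23]],
])

print(pt_example_tensor1.shape)    # torch.Size([3, 2, 4])
print(pt_example_tensor1.numel())  # 24 = 3 * 2 * 4
```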
Next, we apply torch.flatten to the tensor.
Printing the result shows whether the flattening worked.
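The flattening step itself is a single call to torch.flatten. This sketch again recreates the example tensor so that it is self-contained:

```python
import torch

pt_example_tensor1 = torch.Tensor([
    [[3, 1, 4, 5], [2, 8, 6, 7]],
    [[13, 10, 9, 15], [12, 11, 16, 14]],
    [[23, 17, 18, 20], [19, 21, 24, 23]],
])

# Collapse all three dimensions into one
flattened = torch.flatten(pt_example_tensor1)
print(flattened)
print(flattened.shape)  # torch.Size([24])
```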
So the result is a one-dimensional tensor with 24 values: the three matrices have been reduced to a single dimension.
If our goal is only to get a flattened tensor, and not specifically to use the flatten function, we can call view or reshape with -1 (t.view(-1) or t.reshape(-1)) to get the same result. When the output of convolutional layers is sent to a fully connected (nn.Linear) layer, however, the values must be flattened first, and the number of input nodes of that layer (its in_features) must equal the number of values per sample after flattening.
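A sketch of the view/reshape alternative. One caveat: view(-1) needs a memory layout compatible with the new shape, so it raises a RuntimeError on a transposed tensor, whereas reshape(-1) and torch.flatten fall back to copying. The tensor values are arbitrary.

```python
import torch

t = torch.arange(24).reshape(3, 2, 4)

# For a contiguous tensor, all three give the same 24 values
a = t.view(-1)
b = t.reshape(-1)
c = torch.flatten(t)
print(torch.equal(a, b) and torch.equal(b, c))  # True

# After a transpose the tensor is no longer contiguous
tt = t.transpose(0, 2)
try:
    tt.view(-1)
except RuntimeError as err:
    print("view failed:", err)

# reshape and flatten still work because they copy the data
print(torch.equal(tt.reshape(-1), torch.flatten(tt)))  # True
```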
This is a guide to PyTorch Flatten. Here we discussed what PyTorch Flatten is, along with the three parameters of torch.flatten.