PyTorch Flatten

Updated April 6, 2023

Introduction to PyTorch Flatten

PyTorch Flatten is used to reshape a tensor with any number of dimensions into a single dimension so that further operations can be performed on the same input data. The flattened tensor has one dimension whose length equals the number of elements in the original tensor. The main purpose is to collapse all of the tensor's dimensions into one, and it is mostly used to reshape tensors inside a network.
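
For instance, a tensor with shape (2, 3) flattens into a one-dimensional tensor of six elements (the values here are only illustrative):

import torch

t = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])   # shape: torch.Size([2, 3])
flat = torch.flatten(t)         # shape: torch.Size([6])
print(flat)                     # tensor([1, 2, 3, 4, 5, 6])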

What is PyTorch Flatten?

The tensor t is passed as an argument to the flatten function, and since the next step is a reshape, -1 is passed as the reshape argument so that PyTorch infers the size of the single remaining dimension. That size equals the total number of elements in the tensor, which is the product of the sizes of its dimensions. For example, if there are six elements in the tensor, the flattened tensor will have shape (6,). Flattening is also what lets us pass values from a convolutional layer to a linear layer. If needed, we can flatten only some of the dimensions by giving the start_dim and end_dim parameters.
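
As a quick sketch of both points, reshaping with -1 and flattening only a range of dimensions look like this (the shapes are illustrative):

import torch

t = torch.arange(6).reshape(2, 3)   # six elements in total

# Reshape with -1: PyTorch infers the single remaining dimension (here 6).
print(t.reshape(-1).shape)          # torch.Size([6])

# Flatten only part of the tensor with start_dim / end_dim.
x = torch.randn(4, 3, 8, 8)                     # e.g. output of a convolutional layer
y = torch.flatten(x, start_dim=1, end_dim=3)    # keep the batch dimension
print(y.shape)                                  # torch.Size([4, 192]), ready for a linear layer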

PyTorch Flatten Function

The function may return the original tensor, a view of it, or a copy, depending on the input. If the input is already one-dimensional, it is returned unchanged; if the requested flattening can be expressed as a view of the input (for example, when the input is contiguous), a view is returned; and if it cannot, the values are copied and the copy is returned as output. Even a zero-dimensional tensor is returned as a one-dimensional tensor after flattening. So flatten does not always copy the input; it essentially wraps the view behaviour and uses reshape underneath.
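
A small sketch of this behaviour: flattening a zero-dimensional tensor yields a one-dimensional tensor, and for a contiguous input the result shares storage with the input instead of copying it (the tensors below are illustrative):

import torch

# A zero-dimensional tensor flattens to a one-dimensional tensor.
z = torch.tensor(7)
print(torch.flatten(z))               # tensor([7])

# Contiguous input: flatten returns a view sharing the same storage.
a = torch.arange(6).reshape(2, 3)
f = torch.flatten(a)
print(f.data_ptr() == a.data_ptr())   # True, no copy was made

# Non-contiguous input (e.g. after a transpose): a copy is returned.
b = a.t()
g = torch.flatten(b)
print(g.data_ptr() == b.data_ptr())   # False, the values were copied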

There are three ways of flattening tensors in PyTorch. The first is the object-oriented method, where torch.Tensor.flatten is applied directly to the tensor; here the code is written as x.flatten(). The second is the functional method, where the code uses torch.flatten, so flattening x becomes torch.flatten(x). The third is the module form, nn.Flatten(), which is mostly used inside a model definition.
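
A brief sketch of the three call styles on the same tensor (the tensor and its shape are illustrative):

import torch
import torch.nn as nn

x = torch.randn(2, 3, 4)

# 1. Object-oriented method: call flatten() directly on the tensor.
a = x.flatten()                 # shape: torch.Size([24])

# 2. Functional method: pass the tensor to torch.flatten.
b = torch.flatten(x)            # shape: torch.Size([24])

# 3. Module method: nn.Flatten, typically placed inside a model definition.
#    Note that nn.Flatten defaults to start_dim=1, so the batch dimension is kept.
flatten_layer = nn.Flatten()
c = flatten_layer(x)            # shape: torch.Size([2, 12])

print(a.shape, b.shape, c.shape)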

PyTorch Flatten Examples

The snippets below show flatten used in different contexts: a model's forward pass, a tensor serializer, an image-processing forward function, and a coordinate-grid helper.

Code:

def flatten_ex(self, a):
    a = self.model(a)
    a = torch.flatten(a, start_dim=1)  # flatten everything except the batch dimension
    a = self.full_conn1(a)
    a = self.norm1(a)
    a = F.relu(a)
    a = F.dropout(a)
    a = self.full_conn2(a)
    a = F.relu(a)
    a = F.dropout(a)
    a = self.full_conn3(a)
    return a
def tensor_serializer(worker: AbstractWorker, tensor: torch.Tensor) -> TensorDataPB:
    """Strategy to serialize a tensor into a TensorDataPB message"""
    datatype = TORCH_DTYPE_STR[tensor.dtype]
    serialize_tensor = TensorDataPB()
    if tensor.is_quantized:
        serialize_tensor.is_quantized = True
        serialize_tensor.scale = tensor.q_scale()
        serialize_tensor.zero_point = tensor.q_zero_point()
        data = torch.flatten(tensor).int_repr().tolist()
    else:
        data = torch.flatten(tensor).tolist()
    serialize_tensor.dtype = datatype
    serialize_tensor.shape.dims.extend(tensor.size())
    getattr(serialize_tensor, "contents_" + datatype).extend(data)
    return serialize_tensor
def forward_tensor(self, p):
    p = p.reshape(140, 140, 2)
    p = torch.narrow(p, dim=2, start=0, length=2)   # select the channels along the last dimension
    p = p.reshape(1, 2, 140, 140)                   # (batch, channels, height, width)
    p = F.avg_pool2d(p, 5, stride=5)
    p = p / 255
    p = (p - MEAN) / STANDARD_DEVIATION
    p = self.conv1(p)
    p = F.relu(p)
    p = self.conv2(p)
    p = F.max_pool2d(p, 2)
    p = self.dropout1(p)
    p = torch.flatten(p, 1)                         # flatten everything except the batch dimension
    p = self.fc1(p)
    p = F.relu(p)
    p = self.dropout2(p)
    p = self.fc2(p)
    result = F.softmax(p, dim=1)
    return result
def fun_n0(self, a, b, N, dtype):
    n0x, n0y = torch.meshgrid(
        torch.arange(1, a * self.stride + 1, self.stride),
        torch.arange(1, b * self.stride + 1, self.stride))
    n0x = torch.flatten(n0x).view(1, 1, a, b).repeat(1, N, 1, 1)
    n0y = torch.flatten(n0y).view(1, 1, a, b).repeat(1, N, 1, 1)
    n0 = torch.cat([n0x, n0y], 1).type(dtype)
    return n0

PyTorch Flatten parameters

PyTorch flatten has only three parameters: input, start_dim, and end_dim.

  • input (tensor) – the input tensor whose values are to be flattened into one dimension.
  • start_dim (integer) – the first dimension to flatten.
  • end_dim (integer) – the last dimension to flatten.

Let us see an example based on these parameters.

Code:

import torch
import torch.nn as nn

s = torch.tensor([[[29, 28],
                   [27, 26]],
                  [[25, 24],
                   [23, 22]]])
torch.flatten(s)
tensor([29, 28, 27, 26, 25, 24, 23, 22])
torch.flatten(s, start_dim=1)
tensor([[29, 28, 27, 26],
        [25, 24, 23, 22]])
h = torch.randn(16, 2, 15, 54)
f = nn.Sequential(
    nn.Conv2d(2, 2, 16, 54, 2),   # in_channels must match the 2 channels of h
    nn.Flatten()
)
result = f(h)
result.size()
torch.Size([16, 2])

When a module is replicated on the GPU, it is useful to work with flattened parameters, since reducing the tensors to a single dimension keeps the parameter values contiguous and makes the output tensors easier to manage. However, if we flatten the parameters inside the forward pass, extra storage is used and additional compute time is spent on every call.
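
If this refers to recurrent modules such as nn.LSTM, whose replicated weights can end up in non-contiguous memory, PyTorch provides flatten_parameters() for exactly this situation; a minimal sketch, with illustrative layer sizes:

import torch
import torch.nn as nn

# A small LSTM; flatten_parameters() compacts its weights into one contiguous
# block of memory so cuDNN can run efficiently after the module is replicated.
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)

def run(seq):
    # Calling flatten_parameters() inside the forward path avoids the
    # "weights are not part of a single contiguous chunk of memory" warning,
    # at the cost of repeating the compaction work on every call.
    lstm.flatten_parameters()
    out, _ = lstm(seq)
    return out

x = torch.randn(4, 10, 8)    # (batch, time, features)
print(run(x).shape)          # torch.Size([4, 10, 16])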

PyTorch Flatten Transcript

The first step is to import PyTorch from the libraries.

import torch

The PyTorch version can be checked with the code below.

print(torch.__version__)

We can create a sample PyTorch tensor to see the data structure.

pt_example_tensor1 = torch.Tensor(
    [
        [[ 3,  1,  4,  5],
         [ 2,  8,  6,  7]],

        [[13, 10,  9, 15],
         [12, 11, 16, 14]],

        [[23, 17, 18, 20],
         [19, 21, 24, 22]]
    ])

There are two rows and four columns in each matrix. Now let us print the tensor to see its dimensions.

print(pt_example_tensor1)

The result is three matrices with two rows and four columns each, and hence the flattened result will contain 24 values.

Let us use the flatten code in the system.

pt_flattened_example_tensor1 = torch.flatten(pt_example_tensor1)

Now, we can print the result to see whether the flattening in PyTorch worked.

print(pt_flattened_example_tensor1)

So the result is a one-dimensional tensor with 24 values, reduced from three matrices to a single dimension.

Conclusion

If our goal is simply to flatten the tensor, and not necessarily to use the flatten function itself, we can call view or reshape with -1 as the argument and get the same result. However, when sending values to a neural network, we must flatten them before passing them to the linear layers, and the number of input nodes of those layers must match the flattened length.
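
As a closing sketch (layer sizes here are illustrative assumptions): view or reshape with -1 produces the same flat tensor as torch.flatten, and the flattened length fixes the number of input nodes of the following linear layer:

import torch
import torch.nn as nn

x = torch.randn(8, 16, 4, 4)    # e.g. a batch of 8 convolutional feature maps

# Three equivalent ways to flatten everything except the batch dimension.
a = x.view(x.size(0), -1)
b = x.reshape(x.size(0), -1)
c = torch.flatten(x, start_dim=1)
print(a.shape, b.shape, c.shape)    # each is torch.Size([8, 256])

# The flattened length (16 * 4 * 4 = 256) must match the linear layer's in_features.
fc = nn.Linear(256, 10)
print(fc(c).shape)                  # torch.Size([8, 10])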

Recommended Articles

We hope that this EDUCBA information on “PyTorch Flatten” was beneficial to you. You can view EDUCBA’s recommended articles for more information.

  1. What is PyTorch?
  2. PyTorch Versions
  3. Deep Learning Toolbox
  4. Multiple Inheritance in Python
