Pytorch: Add nn.Flatten Layer

Created on 16 Jul 2017  ·  12 Comments  ·  Source: pytorch/pytorch

This layer enables users to write conv layers and fc layers in one nn.Sequential model.

My implementation is here:
https://gist.github.com/VoVAllen/5531c78a2d3f1ff3df772038bca37a83

Quite simple, but convenient for computer vision beginners.
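
For illustration, using such a layer looks roughly like this (a sketch in the spirit of the linked gist; the layer sizes are arbitrary):

import torch
import torch.nn as nn

class Flatten(nn.Module):
    """Proposed layer: keep the batch dimension, flatten everything else."""
    def forward(self, x):
        return x.view(x.size(0), -1)

# Conv and fc layers can now live in a single nn.Sequential.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    Flatten(),                    # (N, 16, 16, 16) -> (N, 4096)
    nn.Linear(16 * 16 * 16, 10),
)

out = model(torch.randn(8, 3, 32, 32))  # shape (8, 10)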

enhancement nn triaged

All 12 comments

we will not be incorporating this layer (as discussed elsewhere in some other issue / PR that I haven't found right now). We don't want to go down the rabbit hole of trying to shoehorn everything into a Sequential structure; that's an explicit choice we made. We think that writing the additional three lines of a forward function is the practice we want to promote.
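
The promoted pattern is roughly the following (a minimal sketch; the specific layers are only illustrative):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.fc = nn.Linear(16 * 32 * 32, 10)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # the extra lines: flatten before the fc layer
        return self.fc(x)

out = Net()(torch.randn(8, 3, 32, 32))  # shape (8, 10)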

Thanks for the reply; I'm curious about the reasoning. I searched for "flatten" in the issues but didn't find the related discussion. Could you give more hints on how to find it?

@VoVAllen I think this discussion is relevant https://discuss.pytorch.org/t/non-legacy-view-module/131

@soumith, I have a use case where I want to parse the PyTorch graph and store the inbound nodes to specific layers. Since Flatten is in the forward function, it will not be recorded in the graph trace.

Specifically, I want to build a map from specific layer indices to their inputs. This requires passing an input to torch.jit.get_trace(). That call fails at the Linear op with a size mismatch, since the flatten is defined in the forward function and is not part of the trace.

@rahul99 Same problem here, do you have any solution?

I understand not wanting to have everything work in a Sequential structure, but this is a super common use case: basically, any time you use convnets in a Sequential block, you need such a layer before the final layer. It's a hugely common use case, not some edge case.

Searching for "class Flatten" "import torch.nn" yields over 4000 results, most of which seem to be variations of @fmassa's example from the forums.

class Flatten(nn.Module):
    def forward(self, input):
        # flatten everything except the batch dimension
        return input.view(input.size(0), -1)

Which makes sense, since as @JulesGM points out, it is a super common thing to do.
Many modules are simply a "hose" kind of affair: stuff goes in one end and out the other. There's no reason why such a module needs more than the forward() method implied by a Sequential.

Searching for "class Lambda" "import torch.nn" yields 1400 results. This pattern allows for quick use of lambdas when defining a Sequential, like:

Lambda(lambda x: x.view(x.size(0), -1))
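
The Lambda wrapper in those results is not part of torch.nn; it is usually defined as a tiny module along these lines:

import torch.nn as nn

class Lambda(nn.Module):
    """Wrap an arbitrary function so it can be used inside nn.Sequential."""
    def __init__(self, fn):
        super(Lambda, self).__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x)

# e.g. a flatten step inside a Sequential:
flatten = Lambda(lambda x: x.view(x.size(0), -1))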

Reopening given that we reached an agreement in https://github.com/pytorch/pytorch/issues/19682 that we will be implementing this.

Thus, I believe we should add nn.Flatten following the semantics of torch.flatten.
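
A sketch of what such a module could look like, mirroring torch.flatten(input, start_dim, end_dim) and defaulting start_dim to 1 so the batch dimension is preserved (the eventual API may differ):

import torch
import torch.nn as nn

class Flatten(nn.Module):
    """Sketch of an nn.Flatten mirroring torch.flatten's semantics."""
    def __init__(self, start_dim=1, end_dim=-1):
        super(Flatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, x):
        return torch.flatten(x, self.start_dim, self.end_dim)

x = torch.randn(8, 16, 4, 4)
print(Flatten()(x).shape)  # torch.Size([8, 256])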

I can probably take this if nobody else is working on it.

@Chillee go for it, I don't think there is anyone working on it

Fixed by #22245
