What are the differences between model.train() and model.eval()?

Last updated on: 5 months ago

model.train() and model.eval() switch a PyTorch network between training mode and evaluation mode. Someone may ask: why do we need two modes for the same network?

Different circumstances

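Actually, some layers behave differently during training and validation. Batch normalization, for example, normalizes each batch with that batch's own statistics during training (and updates its running estimates), but during validation it uses the stored running statistics instead. Dropout is active only during training. Note that switching modes does not turn gradient computation on or off; that is controlled separately, e.g. with torch.no_grad().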
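Because the training/evaluation mode and gradient tracking are independent switches, it can help to see them side by side. The following is a minimal sketch with a small toy model (the layer sizes are arbitrary assumptions): model.eval() alone still builds the autograd graph, while torch.no_grad() is what stops it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy model with layers whose behavior depends on the mode.
model = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3), nn.Dropout(0.5))
x = torch.randn(4, 3)

model.eval()
out_eval = model(x)
# Eval mode changes layer behavior but still tracks gradients.
print(out_eval.requires_grad)    # True

with torch.no_grad():
    out_nograd = model(x)
# Inside no_grad, no autograd graph is recorded.
print(out_nograd.requires_grad)  # False
```

In a typical validation loop both are used together: model.eval() for correct layer behavior, torch.no_grad() to save memory and time.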

model.train() tells your model that you are training it. Layers such as dropout and batch normalization, which behave differently during training and testing, then know which behavior to use. You can check the current mode with the model.training flag: it is True in training mode and False in eval mode.
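To make the difference concrete, here is a small sketch that runs nn.Dropout and nn.BatchNorm1d in both modes on hand-picked inputs (the input values are arbitrary assumptions chosen so the statistics are easy to check by hand).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Dropout: random zeroing only in training mode.
dropout = nn.Dropout(p=0.5)
x = torch.ones(8)
dropout.train()
drop_train = dropout(x)  # each entry is either 0 or scaled to 1 / (1 - 0.5) = 2.0
dropout.eval()
drop_eval = dropout(x)   # identity: returns the input unchanged

# BatchNorm: batch statistics in training mode, running statistics in eval mode.
bn = nn.BatchNorm1d(2)   # running_mean starts at 0, running_var at 1, momentum 0.1
batch = torch.tensor([[1.0, 2.0], [3.0, 6.0]])
bn.train()
bn_train = bn(batch)     # normalized with this batch's mean [2, 4]; running stats updated
bn.eval()
bn_eval = bn(batch)      # normalized with running_mean / running_var instead

print(drop_train, drop_eval, sep="\n")
print(bn.running_mean)   # tensor([0.2000, 0.4000]) = 0.9 * 0 + 0.1 * [2, 4]
print(bn_train)
print(bn_eval)
```

The same input gives different BatchNorm outputs in the two modes, which is why forgetting model.eval() before validation can silently change the reported metrics.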

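import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))

model.train()          # inside every submodule, self.training is now True
print(model.training)  # True

model.eval()           # same as model.train(False): self.training is now False
print(model.training)  # False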
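Your own layers can read the same self.training flag inside forward(). The sketch below uses a hypothetical NoisyLinear layer (not part of PyTorch, invented for illustration) that adds Gaussian noise only in training mode, and shows that model.train() / model.eval() propagate the flag to every submodule.

```python
import torch
import torch.nn as nn


class NoisyLinear(nn.Module):
    """Hypothetical layer: adds Gaussian noise to its output only while training."""

    def __init__(self, in_features: int, out_features: int, noise_std: float = 0.1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.noise_std = noise_std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.linear(x)
        if self.training:  # True after model.train(), False after model.eval()
            out = out + self.noise_std * torch.randn_like(out)
        return out


model = nn.Sequential(NoisyLinear(4, 4), nn.ReLU(), NoisyLinear(4, 2))
x = torch.randn(3, 4)

model.eval()
# eval() reaches every submodule, so no noise is added and outputs repeat exactly.
assert torch.equal(model(x), model(x))
print(all(not m.training for m in model.modules()))  # True

model.train()
print(all(m.training for m in model.modules()))  # True
```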

Which data should we use to plot the ROC curve: the data used with model.train() or the data used with model.eval()?

If the model never saw that data during training, then no, a high score on it is not overfitting. Overfitting means a model fits the training data (and the validation data, if it was used) really well while generalizing poorly to new data.

We use model.eval() during validation. A well-tuned model also does not use the validation set to train itself directly. To plot the ROC curve, we take the model's final output, such as a (# classes, 1) tensor, and convert it (e.g. with a softmax) into the probability of each class. The ROC curve is then calculated from those final probabilities.
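The steps above can be sketched for a binary classifier as follows. This assumes the model outputs one logit per class (shape (batch, 2)); the toy model, the random "validation" data, and the helper names positive_class_probs and roc_curve_binary are all hypothetical. In practice, sklearn.metrics.roc_curve and roc_auc_score are commonly used instead of a hand-written ROC function.

```python
import torch
import torch.nn as nn


def roc_curve_binary(scores: torch.Tensor, labels: torch.Tensor):
    """Return (fpr, tpr) points of the ROC curve; labels are 0/1 with 1 = positive."""
    pos = labels.bool()
    thresholds = torch.unique(scores).flip(0)  # distinct scores, highest first
    # pred[i, j] is True when sample j is predicted positive at thresholds[i].
    pred = scores.unsqueeze(0) >= thresholds.unsqueeze(1)
    tpr = (pred & pos).sum(dim=1).float() / pos.sum()
    fpr = (pred & ~pos).sum(dim=1).float() / (~pos).sum()
    zero = torch.zeros(1)  # the curve starts at (0, 0): threshold above every score
    return torch.cat([zero, fpr]), torch.cat([zero, tpr])


def positive_class_probs(model: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    model.eval()            # dropout off, batchnorm uses running statistics
    with torch.no_grad():   # no gradients are needed for validation
        logits = model(inputs)                # shape (batch, 2): one logit per class
        probs = torch.softmax(logits, dim=1)  # per-class probabilities, rows sum to 1
    return probs[:, 1]                        # probability of the positive class


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 2))
val_x = torch.randn(20, 4)    # stand-in for a held-out validation set
val_y = torch.arange(20) % 2  # alternating 0/1 labels, so both classes are present

scores = positive_class_probs(model, val_x)
fpr, tpr = roc_curve_binary(scores, val_y)
auc = torch.trapz(tpr, fpr).item()  # area under the ROC curve (trapezoidal rule)
print(f"AUROC on the validation set: {auc:.3f}")
```

Using one threshold per distinct score means tied scores move the curve diagonally, so the trapezoidal area counts ties as half-correct.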
