File size: 378 Bytes
f6ca457 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torchvision
def create_vgg_model(
    *,
    weights=torchvision.models.VGG19_Weights.DEFAULT,
    freeze=True,
):
    """Build a VGG19 feature extractor (the model's ``features`` submodule).

    Instantiates torchvision's VGG19 and returns only its ``features``
    submodule; the rest of the model (the classifier head) is discarded.

    Args:
        weights: Passed straight through to ``torchvision.models.vgg19``.
            Defaults to ``VGG19_Weights.DEFAULT`` (pretrained weights, as in
            the original hard-coded version); pass ``None`` for a randomly
            initialized network.
        freeze: If True (the default), set ``requires_grad=False`` on every
            parameter of the returned module so it is not trained.

    Returns:
        The ``features`` submodule of the constructed VGG19 model.
    """
    model = torchvision.models.vgg19(weights=weights)
    # Keep only the features of the model; the classifier is not returned.
    features = model.features
    if freeze:
        # Only the returned submodule's parameters are reachable by callers,
        # so freezing those is equivalent to freezing the whole model.
        features.requires_grad_(False)
    return features
|