File size: 378 Bytes
f6ca457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

import torchvision

def create_vgg_model(*, weights=torchvision.models.VGG19_Weights.DEFAULT, freeze=True):
    """Build the feature-extraction part of a torchvision VGG19 network.

    Args:
        weights: Weights enum forwarded to ``torchvision.models.vgg19``.
            Defaults to ``VGG19_Weights.DEFAULT``, which matches the previous
            hard-coded behaviour. Per torchvision's API, ``None`` requests an
            untrained network.
        freeze: When true (the default), disable gradients on every parameter
            of the returned module so an optimizer will not update it.

    Returns:
        The ``features`` sub-module of the VGG19 model. The rest of the
        network is discarded.
    """
    model = torchvision.models.vgg19(weights=weights)

    # Keep only the features of the model.
    features = model.features

    # Freeze layers. Only the returned sub-module needs freezing, because the
    # discarded parts of the network are never handed to the caller.
    if freeze:
        features.requires_grad_(False)

    return features