Skip to content
Snippets Groups Projects
Commit 7a9d1696 authored by Maximilian Xaver Tiefenbacher's avatar Maximilian Xaver Tiefenbacher
Browse files

Add new file

parents
No related branches found
No related tags found
No related merge requests found
import torch
import sys
import argparse
import numpy as np
#This is just some setup for getting the name of the model if you excecute this file
parser = argparse.ArgumentParser(description='Get the name of the model')
parser.add_argument('m_path', metavar='model_path', type=str, help='enter the name of the model')
args = parser.parse_args()
#The function looks for the model with the name given to it and returns every weight matrix as a numpy array
#if you already know which matrix you are interessted in you can give it to the function as a string and only this matrix will be returned
def model_parser(path,layer=None):
model=torch.load(path,map_location=torch.device('cpu'))
model_dict=model.state_dict()
for name in model_dict:
model_dict[name]=np.array(model_dict[name])
if layer is not None:
return model_dict[layer]
else:
return model_dict
if __name__ == "__main__":
x=model_parser(args.m_path)
for name in x:
print(name)
print("shape: "+str(x[name].shape))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment