"""
Prostate MRI Neural Network - first attempt.

@author: chine
"""
import matplotlib
import matplotlib.pyplot as plt
import torch, os
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import nibabel as nib
from captum.attr import IntegratedGradients
from captum.attr import Saliency
# from captum.attr import DeepLift
from captum.attr import NoiseTunnel
# from captum.attr import visualization as viz
from captum.attr import GuidedBackprop
from captum.attr import InputXGradient
def imshow(img, transpose = True):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# plt.show()
def normalize_image(image_3d):
vmin = torch.min(image_3d)
image_2d = image_3d - vmin
vmax = torch.max(image_2d)
return (image_2d / vmax)
class Dataloader_Decathlon(data.Dataset):
    def __init__(self, path, transform=None, test=False, val=False, split=1.0):
# load data in list
self.test = test
self.transform = transform
if test:
img_folder = '//imagesTs'
else:
img_folder = '//imagesTr'
labels_folder = '//labelsTr'
dir_list = os.listdir(str(path + img_folder))
self.images_list = []
for data_path in dir_list:
try:
self.images_list.append(nib.load(str(path + img_folder + '//' + data_path)))
            except Exception as exc:
                print('Could not load ' + data_path + ': ' + str(exc))
        if not test:
            dir_list = os.listdir(str(path + labels_folder))
            self.labels_list = []
            # note: the first three entries of the labels folder are skipped
            for data_path in dir_list[3:]:
                try:
                    self.labels_list.append(nib.load(str(path + labels_folder + '//' + data_path)))
                except Exception as exc:
                    print('Could not load ' + data_path + ': ' + str(exc))
        if not test:
            if val:
                # validation keeps the last `split` fraction of the cases
                n_keep = int(len(self.images_list) * split)
                self.images_list = self.images_list[len(self.images_list) - n_keep:]
                self.labels_list = self.labels_list[len(self.labels_list) - n_keep:]
            else:
                # training keeps the first `split` fraction of the cases
                self.images_list = self.images_list[:int(len(self.images_list) * split)]
                self.labels_list = self.labels_list[:int(len(self.labels_list) * split)]
def __len__(self):
return len(self.images_list)
def __getitem__(self, idx):
nii_image = self.images_list[idx]
data = torch.from_numpy(np.asarray(nii_image.dataobj))
        if not self.test:
            nii_label = self.labels_list[idx]
            label_ = torch.from_numpy(np.asarray(nii_label.dataobj))
            # slice-level label: 1 if the segmentation mask for that slice is non-empty
            label = torch.tensor([int(int(torch.sum(label_[:,:,i])) > 0) for i in range(label_.shape[2])])
            if self.transform is not None:
                # keep the first modality, transform each slice, then replicate the
                # single channel three times so the volume matches the ResNet input
                data = data[:, :, :, 0:1]
                data_ = torch.tensor([])
                for i in range(data.shape[2]):
                    # scale each slice to 0-255 uint8 so ToPILImage(mode='L') accepts it
                    slice_u8 = (normalize_image(data[:, :, i, 0]) * 255).to(torch.uint8)
                    d = transforms.ToPILImage(mode='L')(slice_u8)
                    data_ = torch.cat((data_, self.transform(d)), 0)
                data = data_.unsqueeze(1).repeat(1, 3, 1, 1)  # (slices, 3, H, W)
            return data, label
        else:
            # test volumes have no labels
            return data
path_data = 'C://Users//chine//Desktop//PythonScripts//captum_tutorial//Task05_Prostate'
# pre processing
# transform_ = transforms.Compose([transforms.ToTensor()])
transform_ = transforms.Compose([
transforms.ToTensor(),
transforms.RandomHorizontalFlip(0.5),
transforms.RandomVerticalFlip(.5),
transforms.RandomAutocontrast(.25),
transforms.Normalize(mean=[.5], std=[.5]),
transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
])
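# A quick sanity check of the pipeline above (a sketch, commented out): the
# transforms map a single-channel PIL image to a (1, H, W) float tensor that is
# roughly in [-1, 1] after the Normalize step.  The 320x320 size is only an
# assumed example, not a property of the dataset.
# from PIL import Image
# _dummy = Image.fromarray(np.random.randint(0, 255, (320, 320), dtype=np.uint8), mode='L')
# print(transform_(_dummy).shape)   # expected: torch.Size([1, 320, 320])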
train_dataloader = Dataloader_Decathlon(path_data,
transform = transform_,
test = False, val = False, split = 0.9)
val_dataloader = Dataloader_Decathlon(path_data,
transform = transform_,
test = False, val = True, split = 0.1)
test_dataloader = Dataloader_Decathlon(path_data,
transform = transform_,
test = True)
scan_num = 2
data, label = train_dataloader.__getitem__(scan_num)
print(label)
print(data.shape)
fig=plt.figure(figsize=(24, 10))
length, width = 3, int(len(label) / 3) + 1
if transform_ is not None:
    num_slices = data.shape[0]
else:
    num_slices = data.shape[2]
for i in range(num_slices):
    fig.add_subplot(length, width, i + 1).set_title('Slice: ' + str(i) + ' Label: ' + str(int(label[i])))
    if transform_ is not None:
        data_slice = data[i:i+1, 0, :, :]
        data_slice = torch.cat([data_slice, data_slice, data_slice], dim=0)
        imshow(normalize_image(data_slice))
    else:
        # drop the trailing channel dimension so matplotlib accepts the 2-D slice
        plt.imshow(data[:, :, i, 0], cmap='gray')
plt.show()
# =============================================================================
# to load data use:
#
# for itr in range(train_dataloader.__len__()):
# image, label = train_dataloader.__getitem__(itr)
#
# instead of:
#
# for itr, (image, label) in enumerate(train_dataloader):
#
# for train_dataloader, val_dataloader and testdataloader
# =============================================================================
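# =============================================================================
# Alternative (a sketch only, not used in this script): wrapping the dataset in
# a standard torch DataLoader with batch_size=None turns off automatic
# collation, so the variable number of slices per volume is not a problem.
# The full module path is used because the name `data` is reused for a tensor
# above.
#
# loader = torch.utils.data.DataLoader(train_dataloader, batch_size=None, shuffle=True)
# for image, label in loader:
#     pass   # image: (slices, 3, H, W), label: (slices,)
# =============================================================================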
classes = ('no prostate', 'prostate') #define the 2 classes
#define a classical neural network
net = torchvision.models.resnet18(pretrained=True)
num_ftrs = net.fc.in_features
net.fc = torch.nn.Linear(num_ftrs, len(classes))
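# Sanity check for the modified head (a sketch, commented out): a dummy batch
# should come back with one logit per class.  The 320x320 input size is only an
# assumption for illustration; ResNet-18's adaptive pooling accepts other sizes.
# with torch.no_grad():
#     print(net(torch.randn(2, 3, 320, 320)).shape)   # expected: torch.Size([2, 2])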
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=.0001)
# TRAIN MODEL "NET" for 10 epochs, or load the saved checkpoint below
USE_PRETRAINED_MODEL = True
if USE_PRETRAINED_MODEL:
print("Using existing trained model")
net.load_state_dict(torch.load('models/decathalon_net.pt'))
else:
for epoch in range(10): # loop over the dataset multiple times
correct = 0
total = 0
running_loss = 0.0
for itr in range(train_dataloader.__len__()):
inputs, label = train_dataloader.__getitem__(itr)
# get the inputs
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, label)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if itr % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, itr + 1, running_loss / 2000))
running_loss = 0.0
for i, p in enumerate(outputs):
if label[i] == torch.max(p.data, 0)[1]:
correct = correct + 1
total = total + 1
# How to combine strings for printing
#print(str('This is the accuracy: ' + str(correct/total)))
# Or
print('Train Accuracy: ', correct/total)
valcorrect = 0
valtotal = 0
for itr in range(val_dataloader.__len__()):
inputs, label = val_dataloader.__getitem__(itr)
            # forward pass only: no gradient updates during validation
            with torch.no_grad():
                outputs = net(inputs)
                loss = criterion(outputs, label)
# print statistics
running_loss += loss.item()
if itr % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, itr + 1, running_loss / 2000))
running_loss = 0.0
for i, p in enumerate(outputs):
if label[i] == torch.max(p.data, 0)[1]:
valcorrect = valcorrect + 1
valtotal = valtotal + 1
# How to combine strings for printing
#print(str('This is the accuracy: ' + str(correct/total)))
# Or
print('Validation Accuracy: ', valcorrect/valtotal)
print('Finished Training')
torch.save(net.state_dict(), 'models/decathalon_net.pt')
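# Suggested step before the attribution code below (an addition, not in the
# original flow): evaluation mode freezes batch-norm statistics so the
# gradient-based attributions are deterministic.
net.eval()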
# dataiter = iter(test_dataloader)
# images = next(dataiter)   # test volumes come back without labels
ind = 10
print('Label for slice', ind, ':', classes[int(label[ind])])
input = data[ind].unsqueeze(0)
input.requires_grad = True
def attribute_image_features(algorithm, input, **kwargs):
net.zero_grad()
tensor_attributions = algorithm.attribute(input,target=label[ind],**kwargs)
return tensor_attributions
fig=plt.figure(figsize=(15, 15))
length, width = 1, 5
fig.add_subplot(length, width, 1).set_title('Original')
imshow(data[ind])  # data[ind] is (3, H, W); imshow transposes it to H x W x 3
print('Saliency Gradient')
saliency = Saliency(net)
grads = saliency.attribute(input, 1)
fig.add_subplot(length, width, 2).set_title('Saliency')
imshow(torchvision.utils.make_grid(normalize_image(grads)))
# grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))
print('saliency-NT')
# Generate Noisetunnel
# ig = IntegratedGradients(net)
nt = NoiseTunnel(saliency)
attr_saliency_nt = attribute_image_features(nt, input, nt_type='smoothgrad_sq',nt_samples=5, stdevs=0.25)
fig.add_subplot(length, width, 3).set_title('Saliency-NT')
imshow(torchvision.utils.make_grid(normalize_image(attr_saliency_nt)))
# attr_ig_nt = np.transpose(attr_ig_nt.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
# print('IG')
# # Generate integrated gradients attribution
# ig = IntegratedGradients(net)
# attr_ig, delta = attribute_image_features(ig, input, baselines=input * 0, return_convergence_delta=True)
# # attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))
# fig.add_subplot(length, width, 4).set_title('Integrated Gradients')
# imshow(torchvision.utils.make_grid(normalize_image(attr_ig.detach().clone())))
# # attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))
# plt.show()
# Input X
print('Input X')
input_x_gradient = InputXGradient(net)
xgrads = input_x_gradient.attribute(input, 1)
fig.add_subplot(length, width, 4).set_title('Input X Grad')
imshow(torchvision.utils.make_grid(normalize_image(xgrads.detach().clone())))
# xgrads = np.transpose(xgrads.squeeze().cpu().detach().numpy(), (1, 2, 0))
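# Optional overlay (a sketch, commented out): Captum's visualization helper can
# blend an attribution map over the original slice.  This assumes the
# `from captum.attr import visualization as viz` import above is re-enabled.
# original_np = np.transpose(data[ind].cpu().numpy(), (1, 2, 0))
# attr_np = np.transpose(xgrads.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
# viz.visualize_image_attr(attr_np, original_np, method='blended_heat_map',
#                          sign='absolute_value', show_colorbar=True,
#                          title='Input X Gradient overlay')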