"""
Prostate Neural Network

Incoming Domain/Website Update

Imminent in...

Prostate MRI Neural Network First Go

@author: chine

"""

import matplotlib


import matplotlib.pyplot as plt

import torch, os

import torch.nn as nn

import torch.optim as optim

import torchvision

import torchvision.datasets as datasets

import torch.utils.data as data

import torchvision.transforms as transforms

import torchvision.models as models


import numpy as np

import nibabel as nib


from captum.attr import IntegratedGradients

from captum.attr import Saliency

# from captum.attr import DeepLift

from captum.attr import NoiseTunnel

# from captum.attr import visualization as viz

from captum.attr import GuidedBackprop

from captum.attr import InputXGradient


def imshow(img, transpose = True):
    """Draw an image tensor on the current matplotlib axes.

    The tensor is first mapped back from the mean-0.5 / std-0.5 normalisation
    (``img / 2 + 0.5``) and then handed to ``plt.imshow``.  The figure is not
    shown here; the caller decides when to call ``plt.show()``.

    Args:
        img: torch tensor to draw.  With ``transpose=True`` it must have three
            axes, which are reordered with ``(1, 2, 0)`` before drawing.
        transpose: when true (default) reorder the axes as described above;
            when false draw the array exactly as given.
    """
    img = img / 2 + 0.5     # unnormalize

    # detach()/cpu() make this safe for tensors that still carry gradients
    # (e.g. attribution maps); plain .numpy() raises on those.
    npimg = img.detach().cpu().numpy()

    # Previously the flag was accepted but ignored and the transpose always ran.
    if transpose:
        npimg = np.transpose(npimg, (1, 2, 0))

    plt.imshow(npimg)

    # plt.show()

        

def normalize_image(image_3d):
    """Min-max scale a tensor into the range [0, 1].

    The global minimum is subtracted and the result is divided by the
    remaining global maximum.

    Args:
        image_3d: non-empty torch tensor of any shape.

    Returns:
        A tensor of the same shape.  A constant input (zero value range) maps
        to all zeros instead of the NaNs a 0 / 0 division would produce.
    """
    shifted = image_3d - torch.min(image_3d)
    vmax = torch.max(shifted)

    # Guard the zero-range case: dividing by one leaves the (all-zero) shifted
    # image untouched while keeping the same result dtype as a real division.
    divisor = vmax if vmax > 0 else torch.ones_like(vmax)

    return shifted / divisor


class Dataloader_Decathlon(data.Dataset):
    """Dataset of NIfTI volumes with one binary label per slice.

    Volumes are read with nibabel from ``<path>/imagesTr`` (or
    ``<path>/imagesTs`` in test mode); label volumes come from
    ``<path>/labelsTr``.  Files that cannot be loaded are reported and skipped.

    Each label is looked up under the *same file name* as its image, so the
    image and label lists always stay aligned by index.
    NOTE(review): this assumes labelsTr mirrors the file names of imagesTr
    (the Decathlon layout) -- confirm for other datasets.

    Args:
        path: root folder that contains the image / label sub-folders.
        transform: when not ``None``, ``__getitem__`` returns the volume
            rearranged to ``(slices, 3, rows, cols)``.
            NOTE(review): the transform itself is not applied to the returned
            data.  The previous per-slice transform result was computed and
            then discarded, so that dead work has been removed; wire the
            transform in deliberately if augmentation is wanted.
        test: load the unlabelled ``imagesTs`` folder; no split is applied.
        val: selects which end of the file list is kept (see ``split``).
        split: fraction of the loaded volumes to keep.  With ``val=True`` the
            first ``split`` fraction is kept, otherwise the last ``split``
            fraction, so e.g. ``val=True, split=0.1`` and
            ``val=False, split=0.9`` do not overlap and the default ``1.0``
            keeps everything.
    """

    def __init__(self, path, transform=None, test=False, val=False, split=1.0):
        self.test = test
        self.transform = transform

        img_dir = os.path.join(path, 'imagesTs' if test else 'imagesTr')
        labels_dir = None if test else os.path.join(path, 'labelsTr')

        self.images_list = []
        if not test:
            self.labels_list = []

        # Sorted so the order (and therefore the split) is reproducible;
        # os.listdir gives no ordering guarantee.
        for file_name in sorted(os.listdir(img_dir)):
            image = self._try_load(os.path.join(img_dir, file_name))
            if image is None:
                continue

            if labels_dir is not None:
                label = self._try_load(os.path.join(labels_dir, file_name))
                if label is None:
                    # Drop the image too, otherwise every later pair would be
                    # shifted by one.
                    continue
                self.labels_list.append(label)

            self.images_list.append(image)

        if not test:
            self.images_list = self._select_split(self.images_list, split, val)
            self.labels_list = self._select_split(self.labels_list, split, val)

    @staticmethod
    def _try_load(file_path):
        """Return the nibabel image at *file_path*, or ``None`` if loading fails."""
        try:
            return nib.load(file_path)
        except Exception as err:  # best effort: report unreadable files and carry on
            print("No access", file_path, err)
            return None

    @staticmethod
    def _select_split(items, split, val):
        """Keep ``split`` of *items*: from the front if *val*, else from the back."""
        total = len(items)
        keep = max(0, min(total, int(total * split)))
        if val:
            return items[:keep]
        return items[total - keep:]

    def __len__(self):
        """Number of volumes kept after loading and splitting."""
        return len(self.images_list)

    def __getitem__(self, idx):
        """Return ``(volume, label)`` for volume number *idx*.

        ``label`` holds one 0/1 entry per index along axis 2 of the label
        volume: 1 where that slice sums to more than zero.  In test mode there
        are no labels and ``label`` is ``None``.
        """
        volume = torch.from_numpy(np.asarray(self.images_list[idx].dataobj))

        label = None
        if not self.test:
            mask = torch.from_numpy(np.asarray(self.labels_list[idx].dataobj))
            # A slice is positive as soon as any of its voxels is labelled.
            label = (mask.sum(dim=(0, 1)) > 0).long()

        if self.transform is not None:
            # Keep only index 0 of the last axis, move the slice axis (axis 2)
            # to the front, and repeat the single remaining channel three
            # times: (rows, cols, slices, 1) -> (slices, 3, rows, cols).
            single = volume[:, :, :, 0:1].permute(2, 3, 0, 1)
            volume = torch.cat([single, single, single], dim=1)

        return volume, label



# Root folder of the dataset; the dataset class looks for its image / label
# sub-folders underneath it.
path_data = 'C://Users//chine//Desktop//PythonScripts//captum_tutorial//Task05_Prostate'

# Pre-processing pipeline handed to all three datasets.
# transform_ = transforms.Compose([transforms.ToTensor()])
transform_ = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(.5),
    transforms.RandomAutocontrast(.25),
    transforms.Normalize(mean=[.5], std=[.5]),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
])

# Despite the "dataloader" names these are Dataset objects that are indexed
# one scan at a time.
train_dataloader = Dataloader_Decathlon(
    path_data, transform=transform_, test=False, val=False, split=0.9)
val_dataloader = Dataloader_Decathlon(
    path_data, transform=transform_, test=False, val=True, split=0.1)
test_dataloader = Dataloader_Decathlon(
    path_data, transform=transform_, test=True)

# Preview one training scan: every slice gets its own subplot, titled with the
# slice index and its label.
scan_num = 2
data, label = train_dataloader[scan_num]
print(label)
print(data.shape)

fig = plt.figure(figsize=(24, 10))
length, width = 3, len(label) // 3 + 1   # three rows, enough columns for all slices

# With a transform the slice axis comes first, otherwise it is axis 2.
has_transform = transform_ is not None
num_slices = data.shape[0] if has_transform else data.shape[2]

for i in range(num_slices):
    axes = fig.add_subplot(length, width, i + 1)
    axes.set_title(f'Slice: {i} Label: {label[i]!s}')

    if has_transform:
        # Take one channel of the slice and repeat it three times before drawing.
        data_slice = data[i:i+1, 0, :, :]
        data_slice = torch.cat([data_slice] * 3, dim=0)
        imshow(normalize_image(data_slice))
    else:
        data_slice = data[:, :, i:i+1, 0]
        plt.imshow(data_slice)

plt.show()



# =============================================================================

# to load data use:

#     

#     for itr in range(train_dataloader.__len__()):

#         image, label = train_dataloader.__getitem__(itr)

# instead of:

#     

#     for itr, (image, label) in enumerate(train_dataloader):

# for train_dataloader, val_dataloader and test_dataloader

# =============================================================================


classes = ('no prostate', 'prostate')  # define the 2 classes

# Network: torchvision ResNet-18 created with pretrained=True, with the final
# fully connected layer replaced by one output per class.
net = torchvision.models.resnet18(pretrained=True)
num_ftrs = net.fc.in_features
net.fc = torch.nn.Linear(num_ftrs, len(classes))

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=.0001)

# True: load the saved weights.  False: train "net" for 10 epochs and save it.
USE_PRETRAINED_MODEL = True


def _run_epoch(dataset, epoch, training):
    """Run one pass over *dataset* and return the slice-level accuracy.

    Every dataset item (one scan) is fed to the network as a single batch.
    With ``training=True`` the weights are updated; otherwise the network is
    put in eval mode and gradients are disabled.  Working inside a function
    keeps the loop variables local, so the module-level ``label`` used by the
    attribution code further down is no longer overwritten by training.

    Returns NaN for an empty dataset instead of dividing by zero.
    """
    net.train(training)
    n_correct = 0
    n_seen = 0
    running_loss = 0.0   # separate per pass, so validation no longer adds to the training sum

    with torch.set_grad_enabled(training):
        for itr in range(len(dataset)):
            inputs, targets = dataset[itr]

            if training:
                # zero the parameter gradients
                optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, targets)

            if training:
                loss.backward()
                optimizer.step()

            # print statistics
            running_loss += loss.item()
            if itr % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, itr + 1, running_loss / 2000))
                running_loss = 0.0

            n_correct += int((outputs.argmax(dim=1) == targets).sum())
            n_seen += len(outputs)

    return n_correct / n_seen if n_seen else float('nan')


if USE_PRETRAINED_MODEL:
    print("Using existing trained model")
    net.load_state_dict(torch.load('models/decathalon_net.pt'))
else:
    for epoch in range(10):  # loop over the dataset multiple times
        print('Train Accuracy: ', _run_epoch(train_dataloader, epoch, training=True))
        print('Validation Accuracy: ', _run_epoch(val_dataloader, epoch, training=False))

    print('Finished Training')

    # Make sure the target folder exists so the trained weights are not lost.
    os.makedirs('models', exist_ok=True)
    torch.save(net.state_dict(), 'models/decathalon_net.pt')

# Everything below only runs inference / attribution, so switch batch-norm to
# its inference behaviour.
net.eval()


# dataiter = iter(test_dataloader)

# images, labels = dataiter


# Index of the slice (within the previewed scan) whose prediction is explained.
ind = 10

# Report the label of that slice.  This used to be plt.show(label[ind]), but
# pyplot.show only takes the `block` flag and does not display values.
print(label[ind])

# Single-slice batch for the attribution methods; gradients with respect to the
# input are required by all of them.
# NOTE: the name shadows the built-in input(); it is kept because the
# attribution code below refers to it.
input = data[ind].unsqueeze(0)
input.requires_grad = True


def attribute_image_features(algorithm, input, **kwargs):
    """Attribute ``input`` with ``algorithm`` towards the label of slice ``ind``.

    Gradients on the module-level ``net`` are cleared first.  The target is
    read from the module-level ``label[ind]``; any extra keyword arguments are
    passed straight through to ``algorithm.attribute``.

    Returns:
        Whatever ``algorithm.attribute`` returns.
    """
    net.zero_grad()
    target = label[ind]
    return algorithm.attribute(input, target=target, **kwargs)

 

# One row of panels: the original slice followed by one panel per attribution
# method.
fig = plt.figure(figsize=(15, 15))
length, width = 1, 5

fig.add_subplot(length, width, 1).set_title('Original')
imshow(data[ind:ind+1,0,:,:])


# Plain saliency, computed for class index 1.
print('Saliency Gradient')
saliency = Saliency(net)
grads = saliency.attribute(input, target=1)
fig.add_subplot(length, width, 2).set_title('Saliency')
imshow(torchvision.utils.make_grid(normalize_image(grads)))
# grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))


# Saliency wrapped in a noise tunnel (smoothgrad_sq, 5 samples), computed for
# the slice's own label via attribute_image_features.
print('saliency-NT')
# Generate Noisetunnel
# ig = IntegratedGradients(net)
nt = NoiseTunnel(saliency)
attr_saliency_nt = attribute_image_features(nt, input, nt_type='smoothgrad_sq',nt_samples=5, stdevs=0.25)
fig.add_subplot(length, width, 3).set_title('Saliency-NT')
imshow(torchvision.utils.make_grid(normalize_image(attr_saliency_nt)))
# attr_ig_nt = np.transpose(attr_ig_nt.squeeze(0).cpu().detach().numpy(), (1, 2, 0))


# Integrated gradients is currently disabled:
# print('IG')
# # Generate integrated gradients attribution
# ig = IntegratedGradients(net)
# attr_ig, delta = attribute_image_features(ig, input, baselines=input * 0, return_convergence_delta=True)
# # attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))
# fig.add_subplot(length, width, 4).set_title('Integraded Gradient')
# imshow(torchvision.utils.make_grid(normalize_image(attr_ig.detach().clone())))
# # attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))
# plt.show()


# Input X gradient, computed for class index 1.
print('Input X')
input_x_gradient = InputXGradient(net)
xgrads = input_x_gradient.attribute(input, target=1)
fig.add_subplot(length, width, 4).set_title('Input X Grad')
# detach(): the result still carries gradient history from the input.
imshow(torchvision.utils.make_grid(normalize_image(xgrads.detach().clone())))
# xgrads = np.transpose(xgrads.squeeze().cpu().detach().numpy(), (1, 2, 0))

# Display the figure; without this nothing appears when the file is run as a
# plain script.
plt.show()


# =============================================================================

# to load data use:

#     

#     for itr in range(train_dataloader.__len__()):

#         image, label = train_dataloader.__getitem__(itr)

# instead of:

#     

#     for itr, (image, label) in enumerate(train_dataloader):

# for train_dataloader, val_dataloader and test_dataloader

# =============================================================================

# ... Links to socials ...

# Instagram, YouTube, Snapchat, TikTok