Machine Learning pl

Wykrywanie chorób płuc na zdjęciach RTG – praktyczne zastosowanie uczenia maszynowego

Podczas poszukiwań inspiracji do nowego projektu natrafiłem na interesujący zbiór danych Chest X-Ray Images (Pneumonia), obejmujący blisko 6 000 zdjęć rentgenowskich, przedstawiających płuca, zarówno osób zdrowych, jak i pacjentów z zapaleniem o podłożu bakteryjnym lub wirusowym [1]. Kolekcja została starannie skategoryzowana i opatrzona przykładowym algorytmem, który umożliwia trenowanie modelu zdolnego do automatycznego rozpoznawania zmian zapalnych na obrazach RTG.

Temat wydał mi się na tyle ciekawy, że postanowiłem go zgłębić i samodzielnie wytrenować model na własnym sprzęcie. Ale jak to zwykle bywa, całość okazała się dużo bardziej złożona, niż pierwotnie przypuszczałem a prawdziwe problemy miały dopiero nadejść, o czym na tym etapie jeszcze nie wiedziałem.

W trakcie prac okazało się, że dołączony do zdjęć kod nie chciał współpracować z nowszymi wersjami bibliotek Pythona, ponieważ część funkcjonalności została z nich albo usunięta, albo znacząco przekształcona. To, co niegdyś uznawano za funkcje eksperymentalne, dziś stało się standardem, natomiast inne elementy przestały być wspierane.

Aby temu zaradzić, postanowiłem przepisać część kodu od podstaw, a niektóre fragmenty zapożyczyć z innych źródeł [5]. Finalna wersja projektu stała się więc niejako „zupą” kodu, na którą naniosłem własne poprawki.

Pomimo mojego wkładu w projekt, nie przypisuję sobie zasług za oryginalne rozwiązania. Ich autorami pozostają programiści, którzy zajęli się tym zagadnieniem przede mną. Moim celem było raczej połączenie wszystkiego w spójną, nowoczesną całość i dostosowanie do współczesnych realiów. Dodatkowo, ponieważ spędziłem nad kodem wiele godzin, wytrenowany w jego efekcie model stał się raczej wartością dodaną, niż celem samym w sobie.

Przejdźmy teraz do części technicznej i przyjrzyjmy się, na jakiej zasadzie automat podejmuje decyzje oraz jak krok po kroku przebiega cały proces klasyfikacji.

Przeglądając załączone zdjęcia, już na pierwszy rzut oka zauważyć można pewną prawidłowość, albowiem obrazy płuc zdrowych wyglądają klarownie, z przeważającym kolorem czarnym, natomiast płuc zainfekowanych posiadają w sobie spore, jasne zmętnienie.

Źródło: Visualization and Interpretation of Convolutional Neural Network Predictions in Detecting Pneumonia in Pediatric Chest Radiographs [7]

Wikipedia w taki sposób wyjaśnia takie zjawisko:

Ważnym testem do postawienia diagnozy zapalenia płuc jest prześwietlenie klatki piersiowej. Zdjęcia takie mogą ujawnić obszary zmętnienia (widoczne jako białe), które reprezentują konsolidację.[2]

Konsolidacja to termin kliniczny oznaczający zestalenie się w zwartą, gęstą masę. […] Skonsolidowana tkanka jest nieprzepuszczalna dla promieni rentgenowskich, dzięki czemu jest wyraźnie widoczna na zdjęciach rentgenowskich i tomografii komputerowej. [3]

Cecha ta będzie istotnym elementem podczas rozpoznawania stanu płuc przez algorytm.

Przejdźmy teraz do omówienia co zawiera paczka danych Chest X-Ray Images [1]. Jest to duży, bo liczący 2 GB plik ZIP, który po rozpakowaniu zawiera następujące katalogi:

  • test (zdjęcia do testowania modelu)
  • train (zdjęcia do trenowania modelu)
  • val (zdjęcia do walidacji/sprawdzenia)

W każdym z powyższych katalogów znajdują się podkatalogi:

  • NORMAL (zdjęcia płuc zdrowych)
  • PNEUMONIA (zdjęcia płuc z wirusowym lub bakteryjnym stanem zapalnym)

Aby móc pracować z danymi, załadujmy wszystkie zdjęcia, z podkatalogów folderu:

\\Users\\Karol\\UczenieMaszynowe\\chest_xray\\

do do pythonowych list (tablic).

path_to_data='\\Users\\Karol\\UczenieMaszynowe\\chest_xray\\'

train_path_data = path_to_data+'train\\'
valid_path_data = path_to_data+'val\\'
test_path_data = path_to_data+'test\\'

train_images_normal = os.listdir(train_path_data+labels[0]+'\\')
train_images_pneumonia =   os.listdir(train_path_data+labels[1]+'\\')
    
valid_images_normal = os.listdir(valid_path_data+labels[0]+'\\')
valid_images_pneumonia = os.listdir(valid_path_data+labels[1]+'\\')

test_images_normal = os.listdir(test_path_data+labels[0]+'\\')
test_images_pneumonia = os.listdir(test_path_data+labels[1]+'\\')
    
print(train_images_normal)

Testowo wyświetlmy zawartość zmiennej train_images_normal, która jest listą zawierają kolejne nazwy plików w folderze.

W przypadku podkatalogu NORMAL lista wygląda następująco:

a tak wygląda odzwierciedlenie w fizycznych plików na dysku:

W przypadku katalogu PNEUMONIA wygląda to analogicznie:

print('train_images_normal:', len(train_images_normal))
print('train_images_pneumonia:', len(train_images_pneumonia))

print('valid_images_normal:', len(valid_images_normal))
print('valid_images_pneumonia:', len(valid_images_pneumonia))

print('test_images_normal:', len(test_images_normal))
print('test_images_pneumonia:', len(test_images_pneumonia))

Za pomocą funkcji len(), sprawdźmy długość powstałych list, aby poznać dane z którymi będziemy pracować (ich skalę). Doświadczenie to pokazuje, że danych treningowych jest odpowiednio:

  • 1 341 płuc zdrowych
  • 3 875 płuc chorych.

Danych do testów jest 234/390, natomiast danych walidujących jest (niestety) tylko po 8.

W kolejnym kroku zapoznajmy się z rozmiarem zdjęć w pikselach. Służy do tego funkcja imege_size_in_pixels(), która przyjmuje następujące parametry:

  • folder (wartości to: val, train lub test)
  • labels (przyjmuje wartości: NORMAL lub PNEUMONIA)
def image_size_in_px(folder, labels):
    
    print('*********** Folder: ', folder, '**********')

    im_shape_x_lists_n = []
    im_shape_x_lists_p = []
    im_shape_y_lists_n = []
    im_shape_y_lists_p = []
    
    if folder=='val':
        path = valid_path_data
        normal = valid_im_n
        pneumonia = valid_im_p
    elif folder=='train':
        path = train_path_data
        normal = train_im_n
        pneumonia = train_im_p
    else:
        path = test_path_data
        normal = test_im_n
        pneumonia = test_im_p
        
    for i, img in enumerate(normal):
        sample = os.path.join(path+labels[0]+'\\', img)
        sample_img = Image.open(sample)
        w, h = sample_img.size
        print('Plik:', sample, 'Rozmiar:',w,'x',h,'px')
        im_shape_x_lists_n.append(w)
        im_shape_y_lists_n.append(h)
        
    for i, img in enumerate(pneumonia):
        sample = os.path.join(path+labels[1]+'\\', img)
        sample_img = Image.open(sample)
        w, h = sample_img.size
        im_shape_x_lists_p.append(w)
        im_shape_y_lists_p.append(h)
        
    return im_shape_x_lists_n, im_shape_y_lists_n, im_shape_x_lists_p, im_shape_y_lists_p

Wywołajmy powyższą funkcję:

im_shape_valid_x_n, im_shape_valid_y_n, im_shape_valid_x_p, im_shape_valid_y_p = image_size_in_px('val', labels)
im_shape_train_x_n, im_shape_train_y_n, im_shape_train_x_p, im_shape_train_y_p = image_size_in_px('train', labels)

im_shape_test_x_n, im_shape_test_y_n, im_shape_test_x_p, im_shape_test_y_p = image_size_in_px('test', labels) 

W ciele funkcji zaszyłem kilka instrukcji print() aby wyświetlić odpowiedź:

[…]

Wizualizacja danych

Skoro dane mamy już załadowane do obiektów Pythona, przyszedł czas, na ich wizualizację. W przypadku płuc zdrowych, użyjemy zmiennej train_im_n, będącą listą zawierającą nazwy plików w podkatalogu NORMAL):

fig = plt.figure(figsize=(15, 10))
npics= 12count = 1

train_im_n_selected = random.sample(train_im_n, 12)for i, img in enumerate(train_im_n_selected):
    sample = os.path.join(train_path_data +labels[0]+'/', img) 
    sample_img = Image.open(sample)   
    sample_img = np.array(sample_img)
    sample_img = sample_img/255.0
    ax = fig.add_subplot(int(npics/3) , 4, count, xticks=[],yticks=[])   
    plt.imshow(sample_img, cmap='gray')
    plt.colorbar()
    count +=1
    
fig.suptitle('Normal')
plt.tight_layout()
plt.show()

W przypadku płuc zainfekowanych, użyjemy zmiennej train_im_p, zawierającą nazwy plików w podkatalogu PNEUMONIA):

fig = plt.figure(figsize=(15, 10))
npics= 12count = 1

train_im_p_selected = random.sample(train_im_p, 12)for i, img in enumerate(train_im_p_selected):
    sample = os.path.join(train_path_data +labels[1]+'/', img) 
    sample_img = Image.open(sample)   
    ax = fig.add_subplot(int(npics/3), 4, count, xticks=[],yticks=[])   
    plt.imshow(sample_img, cmap='gray')
    count +=1
    
fig.suptitle('Pneumonia')
plt.tight_layout()
plt.show()

Trenowanie modelu

Poddajmy nasze dane standaryzacji. Standaryzacja danych w machine learning, polega na przekształceniu danych pierwotnych, aby ich rozkład miał średnią wartość równą 0 i odchylenie standardowe równe 1.

Zabieg taki stosuje się zazwyczaj wtedy, kiedy korelujemy ze sobą zmienne o zupełnie różnych skalach wielkości. Przykładowo, kiedy jedna zmienna oscyluje pomiędzy wartościami od 1 do 10 a druga w granicach milionów, czyli o kilka rzędów wielkości wyższych. Innym powodem przemawiającym za standaryzacją, jest chęć pracy na małych liczbach, których obsługa jest mniej kłopotliwa i szybsza. Po wystandaryzowaniu zmiennych, skala ich wartości przestaje mieć znaczenie, a istotny pozostaje jedynie ich rozrzut, czyli wariancja.

Zobaczmy, jak standaryzacja pomaga rozłożyć wartości pikseli w kilku losowych przykładach.

fig = plt.figure(figsize=(15, 15))
count=1
#print('train_im_n_selected', train_im_n_selected)

for i, img in enumerate(train_im_n_selected):

    # Pasek postepu    
    print('#',end="")
    
    sample_one = os.path.join(train_path_data +labels[0]+'/', img)
    
    #print('sample_one: ', sample_one)
    
    sample_img = Image.open(sample_one)  
        
    sample_img = np.array(sample_img)
    
    #print('sample_img = np.array(sample_img):', sample_img)
    
    sample_img = sample_img/255.0
    
    #print('sample_img = sample_img/255.0', sample_img)
    
    sample_img_mean = np.mean(sample_img)
    
    #print('sample_img_mean', sample_img_mean)
    
    #sample_img_std = np.std(sample_img)
    
    new_sample_img = (sample_img - sample_img_mean)/sample_img_std
    
    ax = fig.add_subplot(int(npics/2) , 3, count, yticks=[])
    
    sns.histplot(new_sample_img.ravel(), 
             label=f'Średnia piksela A {np.mean(new_sample_img):.2f} & Std. A {np.std(new_sample_img):.2f}', kde=False, color='blue', bins=35, alpha=0.8)
    sns.histplot(sample_img.ravel(), 
             label=f'Średnia piksela B {np.mean(sample_img):.2f} & Std. B {np.std(sample_img):.2f}', kde=False, color='red', bins=35, alpha=0.8)
    
    plt.legend(loc='upper center', fontsize=9)
    plt.title('Nazwa pliku: %s'% (img))
    plt.xlabel('Intensywność pikseli')
    plt.ylabel('Piksele w obrazie')
    
    
    count +=1
    
fig.suptitle('Rozklad intensywności pikseli (przed i po standaryzacji)')
plt.tight_layout()
plt.show()

Teraz zrobiło się naprawdę ciekawie…

Zmienna train_im_n_selected jest listą zawierającą nazwy plików treningowych z podkatalogu NORMAL:

['IM-0742–0001.jpeg', 'NORMAL2-IM-0870–0001.jpeg', 'IM-0265–0001.jpeg',…

Rozbijmy w pętli listę na poszczególne nazwy plików, które przechowamy w zmiennej sample_one.

sample_one = '\Users\Karol\UczenieMaszynowe\chest_xray\train\NORMAL\IM-0742–0001.jpeg'

W kolejnych liniach otwieramy zdjęcie, następnie importujemy wartości poszczególnych kodów kolorów każdego z pikseli do listy numpy.

[[ 0  0  0 ... 42 40 40]
 [ 0  0  0 ... 43 44 42]
 [ 0  0  0 ... 41 43 42]
 ...
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]
 [ 0  0  0 ...  0  0  0]]

I teraz najciekawsze: Dzielimy każdą wartość przez 255.

sample_img = sample_img/255.0

Czemu? Otóż wartości kodów kolorów pikseli zawierają się w przedziale od 0 do 256.
Podczas korzystania z obrazu w sieci neuronowej, operowanie na dużych liczbach może stać się bardzo złożone. Aby to uprościć, powinniśmy znormalizować wartości do zakresu od 0 do 1, przy zachowaniu takiego samego rozkładu. Skoro wartości pikseli mieszczą się w zakresie od 0 do 256, więc pomijając liczbę 0, ich zakres wynosi 255. Podzielenie wszystkich wartości przez 255 dokona konwersji na zakres od 0 do 1.

Po takiej operacji nasza macierz przyjmie następujące wartości:

[[0.  0.  0. ... 0.16470588 0.15686275 0.15686275] 
 [0.  0.  0. ... 0.16862745 0.17254902 0.16470588]
 [0.  0.  0. ... 0.16078431 0.16862745 0.16470588]
 ...
 [0.  0.  0. ... 0.         0.         0.        ]
 [0.  0.  0. ... 0.         0.         0.        ]
 [0.  0.  0. ... 0.         0.         0.        ]]
Czyli przykładowo: 
24 / 255 = 0,1647058823529412
40 / 255 = 0,1568627450980392

Na poniższym histogramie widać porównanie rozkładu zestawów danych przed (wykres niebieski) i po standaryzacji (wykres czerwony). Jak widać wariancja obu wykresów wygląda podobnie, różnica polega na tym, że dane wystandaryzowane zawierają się w zakresie od 0 do 1.

Aby uwzględnić ten typ standaryzacji, utwórzmy funkcję, która będzie wykorzystywana jako warstwa lambda w budowaniu modelu.

  • Funkcja reduce_mean() służy do znajdowania średniej elementów w wymiarach tensora.
  • Funkcja reduce_std() służy do znajdowania odchylenia standardowego elementów w wymiarach tensora.
def standardize_layer(tensor):
    tensor_mean = tf.math.reduce_mean(tensor)
    tensor_std = tf.math.reduce_std(tensor)
    new_tensor = (tensor-tensor_mean)/tensor_std
    return new_tensor

Jako, że obrazy są dostępne fizycznie w katalogach na dysku, możemy w tym wypadku użyć funkcji image_dataset_from_directory. Funkcja ta Generuje tf.data.Dataset z plików obrazów z katalogu:

target_size = (300, 300)
input_shape = (300, 300, 1)
batch_size = 64

print("Dataset treningowy")
train_dir = tf.keras.preprocessing.image_dataset_from_directory(
         '\\Users\\Karol\\UczenieMaszynowe\\chest_xray\\train\\', 
         image_size=target_size, 
         batch_size=batch_size,
         shuffle=True,
         color_mode='grayscale',
         label_mode='binary')

print("Dataset walidacyjny")
val_dir = tf.keras.preprocessing.image_dataset_from_directory(
         '\\Users\\Karol\\UczenieMaszynowe\\chest_xray\\val\\', 
         image_size=target_size,
         batch_size=batch_size,
         color_mode='grayscale',
         label_mode='binary')

print("Dataset testowy")
test_dir = tf.keras.preprocessing.image_dataset_from_directory(
         '\\Users\\Karol\\UczenieMaszynowe\\chest_xray\\test', 
         image_size=target_size,
         batch_size=batch_size, 
         color_mode='grayscale',
         label_mode='binary')

Powyższe liczby przedstawiają sumaryczną liczbę plików w podkatalogach NORMAL i PNEUMONIA.

Przypomnę:

Doświadczenie to pokazuje, że danych treningowych jest odpowiednio:
1 341 dla płuc zdrowych
3 875 dla chorych.

Danych do testów jest 234/390, natomiast danych walidujących jest (niestety) tylko po 8.

Mamy więc prostą matematykę…

1341 + 3875 = 5216
234 + 390 = 624
8+8 = 16

… z której wynika, że szesnaście zdjęć danych walidacyjnych to niezbyt rozsądna wielkość.

Aby zwiększyć partię walidacji, najpierw połączmy zestawy danych z katalogu train i val a następnie podzielmy zestaw w stosunku 80% — 20%.

num_elements = tf.data.experimental.cardinality(train_dir).numpy()
print ("Partie treningowe:", num_elements)

num_elements_val = tf.data.experimental.cardinality(val_dir).numpy()
print ("Partie walidacyjne:",num_elements_val)

Widzimy, że są 82 partie treningowe i 1 partia walidacyjna.

Połączmy je (82 + 1 = 83) i podzielmy w stosunku 80%/20%.

new_train_ds = train_dir.concatenate(val_dir)
print (new_train_ds, train_dir)

train_size = int(0.8 * 83) 
val_size = int(0.2 * 83)
    
train_ds = new_train_ds.take(train_size)
val_ds = new_train_ds.skip(train_size).take(val_size)

num_elements_train = tf.data.experimental.cardinality(train_ds).numpy()
print ('num_elements_train:', num_elements_train)

num_elements_val_ds = tf.data.experimental.cardinality(val_ds).numpy()
print ('num_elements_val_ds:', num_elements_val_ds)

Dopiero teraz uzyskaliśmy stosunek danych treningowych i walidacyjnych na rozsądnym poziomie.

Dodajmy warstwę przeskalowującą i niektóre argumentacje również jako warstwy aby zostało to uwzględnione w modelu jako warstwy lambda:

freq_neg = tot_normal_train/(tot_normal_train + tot_pneumonia_train)
freq_pos = tot_pneumonia_train/(tot_normal_train + tot_pneumonia_train)

pos_weights = np.array([freq_neg])
neg_weights = np.array([freq_pos])

print ('check positive weight: ', pos_weights, len(pos_weights))
print ('check negative weight: ', neg_weights)


def get_weighted_loss(pos_weights, neg_weights, epsilon=1e-7):
    """
    Return weighted loss function given negative weights and positive weights.

    Args:
      pos_weights (np.array): array of positive weights for each class, size (num_classes)
      neg_weights (np.array): array of negative weights for each class, size (num_classes)
    
    Returns:
      weighted_loss (function): weighted loss function
    """
    def weighted_loss(y_true, y_pred):
        """
        Return weighted loss value. 

        Args:
            y_true (Tensor): Tensor of true labels, size is (num_examples, num_classes)
            y_pred (Tensor): Tensor of predicted labels, size is (num_examples, num_classes)
        Returns:
            loss (float): overall scalar loss summed across all classes
        """
        # initialize loss to zero
        loss = 0.0

        for i in range(len(pos_weights)): # we have only 1 class 
            # for each class, add average weighted loss for that class 
            loss += - (K.mean((pos_weights[i] * y_true[:, i] * K.log(y_pred[:, i] + epsilon)) + 
                              (neg_weights[i] * (1-y_true[:, i]) * K.log(1-y_pred[:, i] + epsilon)) ) )
        return loss
    return weighted_loss

Co to jest funkcja straty ważonej?

Funkcja straty ważonej rozwiązuje typowy problem niezrównoważonych danych. W naszym przypadku mamy dużo więcej zdjęć klatki piersiowej z zapaleniem płuc niż zdjęć płuc zdrowych. Jeśli taki zbiór podamy modelowi bez korekty, model zacznie niejako „faworyzować” klasę dominującą (czyli w naszym przypadku – pneumonia) i ignorować rzadką klasę (płuca zdrowe), bo taka strategia nadal daje mu niezłą accuracy.

Trenowanie modelu

Za proces trenowania modelu odpowiada poniższy kod….

start_time = time.time()
history = model.fit(train_data_batches, 
                    epochs=100, 
                    validation_data=valid_data_batches,
                    callbacks=[mcp_save, es, reduce_lr])

end_time = time.time()

…który na moim serwerze wykonywał się 20 godzin! Tak długi czas spowodowany był faktem, że serwer którym dysponuję pozbawiony jest procesora GPU, co zmusiło mnie do użycia wyłącznie CPU.

Po przeprowadzeniu (stanowczo zbyt długiego) trenowania, możemy przejść do testów…

plt.figure(figsize=(16, 12))
for images, labels in test_data_batches.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        
        y_pred_batch = model.predict(tf.expand_dims(images[i], axis=0, ))
        y_pred_75th = (y_pred_batch > 0.75).astype(np.uint8)
        original_label = class_names[labels[i].numpy().astype("uint8")[0]]
        predicted_label = class_names[y_pred_75th[0].astype("uint8")[0]]
        plt.title(f'Oryg: {original_label}  ; Przewidywany: {predicted_label}  ')
        plt.axis("off") 

…z których wynika, że model prawidłowo zaklasyfikował poddane testom zdjęcia!.

Czyli innymi słowy na poprawnie przewidział czy podane zdjęcie przedstawia obraz płuc zdrowych, czy zainfekowanych, a takiego wyniku właśnie oczekiwałem! 😉

Źródła:

[1] Chest X-Ray Images (Pneumonia) https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia

[2] Wikipedia Pneumonia chest x ray: https://www.wikidoc.org/index.php/Pneumonia_chest_x_ray

[3] Wikipedia Consolidation: https://www.wikidoc.org/index.php/Consolidation_(medicine)

[4] Materiały szkoleniowe machine-learning https://www.udemy.com/topic/tensorflow/

[5] Repozytoria:

https://www.kaggle.com/code/ahmedramsey/chest-x-ray

https://www.kaggle.com/code/razarizwanahmed/chest-x-ray-pneumonia-classification-0-98-acc

https://www.kaggle.com/code/suvoooo/detectpneumonia-inceptionresnetv2-class-imbalance

https://www.kaggle.com/datasets/paultimothymooney/chest-xray-pneumonia/code

[6] Chest X-ray & Pneumonia: Deep Learning with TensorFlow Saptashwa Bhattacharyya https://towardsdatascience.com/chest-x-ray-pneumonia-deep-learning-with-tensorflow-a58a9e6ade70

[7] Visualization and Interpretation of Convolutional Neural Network Predictions in Detecting Pneumonia in Pediatric Chest Radiographs https://www.mdpi.com/2076-3417/8/10/1715/htm