Obsah
1. Framework Torch: problematika rozpoznávání a klasifikace obrázků
2. První verze generátoru trénovacích obrázků: číslic od 0 do 9
3. Ukázka vygenerovaných obrázků
4. Jednoduché zašumění trénovacích obrázků
6. Použití klasické neuronové sítě pro rozpoznání číslic v obrázku
8. Grafické zobrazení klasifikace verifikačních obrázků
9. Ukázka výsledků odhadu sítě
10. Vliv postupného zvyšování šumu
11. Úplný zdrojový kód prvního příkladu
13. Verifikace sítě s posunutými obrázky
14. Ukázka výsledků odhadu sítě pro posunuté obrázky
15. Zdrojový kód druhého příkladu
16. Vylepšení architektury neuronových sítí aneb konvoluční sítě
17. Vrstvy v konvolučních sítích
18. Repositář s demonstračními příklady
1. Framework Torch: problematika rozpoznávání a klasifikace obrázků
Na předchozí dvě části [1] [2] seriálu o frameworku Torch, v nichž jsme se seznámili s postupem, který se používá při tvorbě umělých neuronových sítí s pravidelnou strukturou tvořenou jednotlivými vrstvami, u nichž učení probíhá s využitím takzvaného backpropagation algoritmu (algoritmu zpětného šíření), dnes navážeme. Budeme se totiž zabývat problematikou rozpoznávání a klasifikace rastrových obrázků, které sice budou zpočátku velmi malé a budou obsahovat poměrně dobře predikovatelná data, ovšem i na takto malých obrázcích si ukážeme některé nevýhody klasických obecných neuronových sítí při jejich aplikaci na rastrová data.
V závěru článku se navíc seznámíme s principy, na nichž jsou postaveny takzvané konvoluční neuronové sítě. Ty jsou dnes velmi populární, a to hned z několika důvodů – po natrénování sítě (to je sice časově náročné, ovšem s moderními GPU již většinou uspokojivě řešitelné) jsou již konvoluční sítě poměrně rychlé a především se rozšiřují možnosti, kde je možné tyto sítě prakticky použít (doprava, průmysl atd.).
2. První verze generátoru trénovacích obrázků: číslic od 0 do 9
Jak jsme si již řekli v úvodním odstavci, budeme se dnes snažit s využitím jednoduchých neuronových sítí rozpoznávat objekty na velmi malých obrázcích. Konkrétně se bude jednat o vstupní obrázky s pevným rozlišením pouhých 8×8 pixelů, což nám mj. umožní velmi rychlý tréning sítě. Obrázky budou reprezentovány ve stupních šedi a úkolem postupně vytvářené neuronové sítě bude na těchto obrázcích rozpoznat číslice 0 až 9 zapsané pro jednoduchost předem známým fontem (příště už budeme mít horší úkol, protože číslice budou napsány rukou). Abychom získali představu, jak tyto číslice vypadají, necháme si vygenerovat testovací obrázky, a to z následujících vstupních dat:
digits = { {0x00, 0x3C, 0x66, 0x76, 0x6E, 0x66, 0x3C, 0x00}, {0x00, 0x18, 0x1C, 0x18, 0x18, 0x18, 0x7E, 0x00}, {0x00, 0x3C, 0x66, 0x30, 0x18, 0x0C, 0x7E, 0x00}, {0x00, 0x7E, 0x30, 0x18, 0x30, 0x66, 0x3C, 0x00}, {0x00, 0x30, 0x38, 0x3C, 0x36, 0x7E, 0x30, 0x00}, {0x00, 0x7E, 0x06, 0x3E, 0x60, 0x66, 0x3C, 0x00}, {0x00, 0x3C, 0x06, 0x3E, 0x66, 0x66, 0x3C, 0x00}, {0x00, 0x7E, 0x60, 0x30, 0x18, 0x0C, 0x0C, 0x00}, {0x00, 0x3C, 0x66, 0x3C, 0x66, 0x66, 0x3C, 0x00}, {0x00, 0x3C, 0x66, 0x7C, 0x60, 0x30, 0x1C, 0x00}, }
Každá číslice, jejíž tvar je zakódován v poli digits, je reprezentována osmicí bajtů, protože každý bajt reprezentuje osm sousedních pixelů. Celkem tedy vstupní data obsahují osmdesát bajtů (deset číslic × osm bajtů).
Rastrové podoby jednotlivých číslic se uloží do externích souborů s využitím formátu PGM (Portable GrayMap), který již na stránkách Rootu byl poměrně podrobně popsán. Následující funkce se postará a vytvoření rastrového obrázku ze vstupních dat (používá se výše zmíněné pole digits). Povšimněte si, že využíváme jedné vlastnosti specifické pro PGM – uvedeme, že maximální hodnota pixelu je rovna 1, tudíž se nemusíme starat o převod jeho světlosti do rozsahu 0..255:
function generate_exact_image(filename, digit) if digit < 0 or digit > 9 then return end codes = digits[digit+1] local fout = io.open(filename, "w") if not fout then return end -- hlavicka fout:write("P2\n8 8\n1\n") for _, code in ipairs(codes) do -- pouze pro ladeni local s = "" for i = 1,8 do local bit = code % 2 fout:write(bit) fout:write(" ") -- pouze pro ladeni s = s .. bit code = (code - bit)/2 end -- pouze pro ladeni print(s) end print() fout:close() end
Příklad jednoduché bitmapy o rozměrech 8×8 pixelů. Bitmapa obsahuje tvar číslice 0:
P2 8 8 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 1 1 0 0 1 1 0 0 1 1 0 1 1 1 0 0 1 1 1 0 1 1 0 0 1 1 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0
Poznámka: ve skutečnosti se v datech bitmapy nemusí používat konce řádků (ty „jen“ zvyšují čitelnost pro člověka), takže je možný i tento formát:
P2 8 8 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 1 1 0 0 1 1 0 0 1 1 0 1 1 1 0 0 1 1 1 0 1 1 0 0 1 1 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0
3. Ukázka vygenerovaných obrázků
Rastrové obrázky s tvary jednotlivých číslic již dokážeme vygenerovat triviálním způsobem:
for digit = 0, 9 do local filename = string.format("digit_%d.pgm", digit) generate_exact_image(filename, digit) end
Podívejme se nyní na vygenerované výsledky, tj. na bitmapy s číslicemi. Všechny obrázky byly v horizontálním i vertikálním směru zvětšeny dvacetkrát, takže se z mini-bitmap o rozměrech 8×8 pixelů staly už dobře rozpoznatelné bitmapy o rozměrech 160×160 pixelů:
Obrázek 1: Tvar číslice 0.
Obrázek 2: Tvar číslice 1.
Obrázek 3: Tvar číslice 2.
Obrázek 4: Tvar číslice 3.
Obrázek 5: Tvar číslice 4.
Obrázek 6: Tvar číslice 5.
Obrázek 7: Tvar číslice 6.
Obrázek 8: Tvar číslice 7.
Obrázek 9: Tvar číslice 8.
Obrázek 10: Tvar číslice 9.
4. Jednoduché zašumění trénovacích obrázků
Po natrénování neuronové sítě s využitím pouhých deseti vstupních obrázků by se mohlo stát, že by síť prakticky vůbec nebyla schopna rozeznat i nepatrně změněná vstupní data. Proto funkci pro vytvoření trénovacích obrázků vhodně pozměníme takovým způsobem, že se do obrázků zanese šum. Pro vytvoření šumu používám standardní funkci math.random, ovšem v případě potřeby samozřejmě můžete využít i funkci pro generování náhodných hodnot s normálním rozložením, která je součástí samotné knihovny Torch a kterou jsme již použili pro trénink předchozích neuronových sítí. Povšimněte si dále, že i zašuměné obrázky mají přesně stanovenou hranici mezi pixely, které tvoří číslici a pixely tvořícími pozadí. Tuto část si samozřejmě můžete upravit, a to i takovým způsobem, aby tato hranice byla z obou stran překračována (ovšem takto obecně naučená síť nebude dávat jednoznačné výsledky – ostatně si to sami vyzkoušejte):
function generate_training_image(filename, digit, noise_amount) if digit < 0 or digit > 9 then return end codes = digits[digit+1] local fout = io.open(filename, "w") if not fout then return end -- hlavicka fout:write("P2\n8 8\n255\n") for _, code in ipairs(codes) do for i = 1,8 do local bit = code % 2 fout:write(192*bit + math.random(0,noise_amount)) fout:write(" ") s = s .. bit code = (code - bit)/2 end end fout:close() end
Nyní si již můžeme vygenerovat libovolnou sekvenci trénovacích obrázků. Pokud budete potřebovat víc obrázků, stačí změnit pole NOISE či celočíselnou konstantu REPEAT_COUNT:
for _, noise in ipairs(NOISE) do for digit = 0, 9 do for i = 1, REPEAT_COUNT do local filename = string.format("%d_%d_%d.pgm", digit, noise, i) generate_image(filename, digit, noise) end end end
Poznámka: v praxi se nebudeme zdržovat generováním souborů s obrázky, ale budeme tvořit přímo tenzory určené pro vstup do neuronové sítě. Výše zmíněná funkce generate_training_image ovšem může dobře posloužit pro vizualizaci vstupů sítě.
5. Ukázka zašuměných obrázků
Opět se pro zajímavost podívejme na obrázky vykreslené funkcí generate_training_image, zde konkrétně pro číslici 2. Pokud vám připadne, že první obrázek obsahuje jen dvě barvy, zkuste si ho otevřít v grafickém editoru či v grafickém prohlížeči a následně si zobrazit jeho histogram (celkem se používá 29 různých odstínů šedi v 64 pixelech):
Obrázek 11: Nepatrně zašuměná číslice 2.
Obrázek 12: Zvýšení míry šumu.
Obrázek 13: Další zvýšení míry šumu, tentokrát již jasně viditelné.
6. Použití klasické neuronové sítě pro rozpoznání číslic v obrázku
Jak připravit trénovací data již víme, takže nám již zbývá „maličkost“ – pokusit se vytvořit vhodnou neuronovou sít pro rozpoznávání číslic na obrázcích. Některé parametry sítě přitom předem vyplývají z podstaty řešeného problému, další parametry pouze odhadneme:
- Vstupů sítě bude 64, protože budeme zpracovávat bitmapy s 8×8 pixely.
- Výstupů sítě bude 10, přičemž každý výstup bude určovat, do jaké míry síť odhadla, že obrázek obsahuje danou číslici. Na výstup se můžeme dívat jako na vektor, kde v ideálním případě bude devět prvků nulových a jeden prvek bude obsahovat hodnotu 1. Index tohoto prvku pak přímo určuje hodnotu nalezené číslice (pokud indexujeme od nuly). V praxi se budou výstupy obsahovat i jiné hodnoty, ale vždy by mělo být možné najít jeden prvek s výrazně větší hodnotou.
- Pro jednoduchost bude síť obsahovat jen jednu skrytou vrstvu, v níž bude 100 neuronů (více, než na vstupní vrstvě).
Parametry neuronové sítě:
INPUT_NEURONS = 64 HIDDEN_NEURONS = 100 OUTPUT_NEURONS = 10
Parametry pro učení neuronové sítě:
MAX_ITERATION = 200 LEARNING_RATE = 0.01
V síti použijeme nelineární aktivační funkce Tanh, takže celá struktura sítě bude vypadat následovně:
nn.Sequential { [input -> (1) -> (2) -> (3) -> (4) -> output] (1): nn.Linear(64 -> 100) (2): nn.Tanh (3): nn.Linear(100 -> 10) (4): nn.Tanh }
Funkce pro vytvoření trénovacích dat musí připravit sérii tenzorů představujících vstup do sítě. Použijeme přitom upravenou funkci generate_image_data, která již nebude vytvářet obrázek 8×8 pixelů, ale tenzor s 64 komponentami:
function generate_image_data(digit, noise_amount) codes = digits[digit+1] local index = 1 local result = torch.Tensor(8*8) for _, code in ipairs(codes) do for i = 1,8 do local bit = code % 2 local value = 192*bit + math.random(0,noise_amount) result[index] = value index = index + 1 code = (code - bit)/2 end end return result end function prepare_training_data() local training_data_size = #NOISE * REPEAT_COUNT * DIGITS local training_data = {} function training_data:size() return training_data_size end local index = 1 for _, noise_amount in ipairs(NOISE) do for digit = 0, 9 do for i = 1, REPEAT_COUNT do local input = generate_image_data(digit, noise_amount) local output = generate_expected_output(digit) training_data[index] = {input, output} index = index + 1 end end end return training_data end
Poznámka: zápis #NOISE v jazyce Lua znamená, že se vrátí počet prvků v poli NOISE.
7. Průběh tréninku sítě
Podívejme se nyní, jak vypadá trénink sítě. Chyba postupně klesá, a to zpočátku dosti výrazně. To může znamenat dvě věci: buď se nám podařilo navrhnout vhodnou strukturu sítě, nebo nejsou trénovací data dostatečně variabilní a síť proto nebude dostatečně adaptována pro reálná data (uvidíme dále):
# StochasticGradient: training # current error = 0.09230117269206 # current error = 0.032254759333561 # current error = 0.018699736965155 # current error = 0.013066257387987 # current error = 0.009324769702847 # current error = 0.0085818671242583 # current error = 0.0065817888003809 # current error = 0.0051915683049868 # current error = 0.0043908250531581 # current error = 0.0037979578861573 ... ... ... # current error = 0.00011141314090753 # current error = 0.00011082007280134 # StochasticGradient: you have reached the maximum number of iterations # training error = 0.00011082007280134
Výsledná chyba je již dostatečně nízká, takže si můžeme naši sít verifikovat.
8. Grafické zobrazení klasifikace verifikačních obrázků
Zatímco u předchozích neuronových sítí nám stačilo si vypsat odhadovanou hodnotu, porovnat ji s hodnotou očekávanou a následně vypočítat chybu, u dnešní sítě zvolíme jiný postup. Necháme si vykreslit graf, který pro různé vstupní obrázky zobrazí všech deset odhadů číslic. Přitom by jeden odhad měl výrazně převyšovat ostatní odhady. Aby se mohl vykreslit 2D graf, je nutné vytvořit 2D tenzor s výsledky. Tenzor bude mít rozměry počet_odhadů×počet_rozpoznávaných_číslic, kde počet_rozpoznávaných_číslic je roven deseti. Tenzor před vykreslením transponujeme metodou t(). Validace tedy může vypadat takto:
function validate_neural_network(network, digit, noise_amount) local data_size = 100 local values = torch.Tensor(data_size, DIGITS) for i = 1, data_size do local input = generate_image_data(digit, noise_amount) local output = network:forward(input) values[i] = output --print(output) end local filename = string.format("digit%d_noise%d.png", digit, noise_amount) plot_graph(filename, values:t()) end
Funkce, která vykreslí 2D graf s odhady sítě, je velmi jednoduchá, protože jí již předáváme 2D tenzor, jehož hodnoty se bez dalších úprav vynesou do grafu ve formě barev:
function plot_graph(filename, values) gnuplot.pngfigure(filename) gnuplot.imagesc(values, 'color') gnuplot.plotflush() gnuplot.close() end
9. Ukázka výsledků odhadu sítě
Podívejme se nyní na vytvořené grafy. Na horizontální osu jsou vynesena čísla jednotlivých měření (bylo jich celkem provedeno sto), na osu vertikální přímo hodnoty číslic. Podívejme se na první graf, na němž je konstantní plocha představující nuly a jediný žlutý pruh představující váhu 1. Neuronová sít pro nezašuměný obrázek s číslicí 1 vždy na 100% tuto číslici odhadla:
Obrázek 14: Vstupem je obrázek s číslicí 1. Míra šumu je nastavena na 0.
Na druhém grafu je výsledek odhadu sítě pro nepatrně zašuměné obrázky, konkrétně pro obrázky, v nichž se hodnoty „černých“ pixelů pohybují v rozsahu 0..15 a hodnoty pixelů „bílých“ v rozsahu 192..207. Zde je patrné, že u minimálně dvou vstupních obrázků si síť nebyla na 100% jistá výsledkem:
Obrázek 15: Vstupem je obrázek s číslicí 1. Míra šumu je nastavena na 15.
Čím větší je šum zanesený do obrázku, tím méně jistoty nalezneme u odhadu sítě.
Obrázek 16: Vstupem je obrázek s číslicí 1. Míra šumu je nastavena na 30.
Obrázek 17: Vstupem je obrázek s číslicí 1. Míra šumu je nastavena na 45.
I pro hodně zašuměné obrázky (přitom šum přesahuje míru použitou při tréninku sítě!) stále dostáváme použitelné výsledky, i když ne tak jednoznačné.
Obrázek 18: Vstupem je obrázek s číslicí 1. Míra šumu je nastavena na 60.
Podobné výsledky, ovšem pro vstupní obrázky s číslicí 3:
Obrázek 19: Vstupem je obrázek s číslicí 3. Míra šumu je nastavena na 0.
Obrázek 20: Vstupem je obrázek s číslicí 3. Míra šumu je nastavena na 15.
Obrázek 21: Vstupem je obrázek s číslicí 3. Míra šumu je nastavena na 30.
Obrázek 22: Vstupem je obrázek s číslicí 3. Míra šumu je nastavena na 45.
Obrázek 23: Vstupem je obrázek s číslicí 3. Míra šumu je nastavena na 60.
10. Vliv postupného zvyšování šumu
Zajímavé bude sledovat, jak se bude odhad sítě zhoršovat s rostoucím šumem. Proto si vytvoříme novou funkci validate_neural_network_variable_noise, v níž vykreslíme podobné grafy pro verifikační data, ovšem nyní se bude s každým měřením zvětšovat míra šumu až na hodnotu 64 (tj. pixely „bílé“ a „černé“ ve skutečnosti mohou nabývat jedné z 64 hodnot). Funkce vypadá takto:
function validate_neural_network_variable_noise(network, digit) local data_size = 64 local values = torch.Tensor(data_size, DIGITS) for noise_amount = 0, data_size-1 do local input = generate_image_data(digit, noise_amount) local output = network:forward(input) values[noise_amount+1] = output end local filename = string.format("digit%d_variable_noise.png", digit) plot_graph(filename, values:t()) end
Výsledkem je pouhých deset grafů pro deset číslic, takže si je uvedeme všechny. Zajímavé je zjištění, že při zvětšujícím se zašumění se jistota sítě v odhadu číslice liší podle toho, jaký tvar je rozpoznáván:
Obrázek 24: Vstupem jsou obrázky s číslicí 0. Nejpodobnější jsou číslice 6 a 8.
Obrázek 25: Vstupem jsou obrázky s číslicí 1.
Obrázek 26: Vstupem jsou obrázky s číslicí 2. Nejpodobnější je číslice 3.
Obrázek 27: Vstupem jsou obrázky s číslicí 3. Nejpodobnější je osmička.
Obrázek 28: Vstupem jsou obrázky s číslicí 4.
Obrázek 29: Vstupem jsou obrázky s číslicí 5.
Obrázek 30: Vstupem jsou obrázky s číslicí 6.
Obrázek 31: Vstupem jsou obrázky s číslicí 7.
Obrázek 32: Vstupem jsou obrázky s číslicí 8.
Obrázek 33: Vstupem jsou obrázky s číslicí 9.
11. Úplný zdrojový kód prvního příkladu
Pod tímto odstavcem je vypsán úplný zdrojový kód dnešního prvního demonstračního příkladu s jednoduchou sítí se třemi vrstvami, která rozpoznává číslice napsané předem známým fontem, přičemž obrázky s číslicemi mohou být do určité míry zašuměny. Zdrojový kód najdete i na adrese https://github.com/tisnik/torch-examples/blob/master/nn/bitmapnn/01_noisy_images.lua:
require("nn") require("gnuplot") -- parametry neuronove site INPUT_NEURONS = 64 HIDDEN_NEURONS = 100 OUTPUT_NEURONS = 10 -- parametry pro uceni neuronove site MAX_ITERATION = 200 LEARNING_RATE = 0.01 NOISE = {0, 8, 16}--, 32} REPEAT_COUNT = 5 DIGITS = 10 digits = { {0x00, 0x3C, 0x66, 0x76, 0x6E, 0x66, 0x3C, 0x00 }, {0x00, 0x18, 0x1C, 0x18, 0x18, 0x18, 0x7E, 0x00 }, {0x00, 0x3C, 0x66, 0x30, 0x18, 0x0C, 0x7E, 0x00 }, {0x00, 0x7E, 0x30, 0x18, 0x30, 0x66, 0x3C, 0x00 }, {0x00, 0x30, 0x38, 0x3C, 0x36, 0x7E, 0x30, 0x00 }, {0x00, 0x7E, 0x06, 0x3E, 0x60, 0x66, 0x3C, 0x00 }, {0x00, 0x3C, 0x06, 0x3E, 0x66, 0x66, 0x3C, 0x00 }, {0x00, 0x7E, 0x60, 0x30, 0x18, 0x0C, 0x0C, 0x00 }, {0x00, 0x3C, 0x66, 0x3C, 0x66, 0x66, 0x3C, 0x00 }, {0x00, 0x3C, 0x66, 0x7C, 0x60, 0x30, 0x1C, 0x00 }, } function generate_image_data(digit, noise_amount) codes = digits[digit+1] local index = 1 local result = torch.Tensor(8*8) for _, code in ipairs(codes) do for i = 1,8 do local bit = code % 2 local value = 192*bit + math.random(0,noise_amount) result[index] = value index = index + 1 code = (code - bit)/2 end end return result end function generate_expected_output(digit) local result = torch.zeros(DIGITS) result[digit+1] = 1 return result end function prepare_training_data() local training_data_size = #NOISE * REPEAT_COUNT * DIGITS local training_data = {} function training_data:size() return training_data_size end local index = 1 for _, noise_amount in ipairs(NOISE) do for digit = 0, 9 do for i = 1, REPEAT_COUNT do local input = generate_image_data(digit, noise_amount) local output = generate_expected_output(digit) training_data[index] = {input, output} index = index + 1 end end end return training_data end function construct_neural_network(input_neurons, hidden_neurons, output_neurons) local network = nn.Sequential() network:add(nn.Linear(input_neurons, hidden_neurons)) network:add(nn.Tanh()) network:add(nn.Linear(hidden_neurons, output_neurons)) -- pridana nelinearni funkce network:add(nn.Tanh()) return network end function train_neural_network(network, training_data, learning_rate, max_iteration) local criterion = nn.MSECriterion() local trainer = nn.StochasticGradient(network, criterion) trainer.learningRate = learning_rate trainer.maxIteration = max_iteration trainer:train(training_data) end function plot_graph(filename, values) gnuplot.pngfigure(filename) gnuplot.imagesc(values, 'color') gnuplot.plotflush() gnuplot.close() end function validate_neural_network(network, digit, noise_amount) local data_size = 100 local values = torch.Tensor(data_size, DIGITS) for i = 1, data_size do local input = generate_image_data(digit, noise_amount) local output = network:forward(input) values[i] = output --print(output) end local filename = string.format("digit%d_noise%d.png", digit, noise_amount) plot_graph(filename, values:t()) end function validate_neural_network_variable_noise(network, digit) local data_size = 64 local values = torch.Tensor(data_size, DIGITS) for noise_amount = 0, data_size-1 do local input = generate_image_data(digit, noise_amount) local output = network:forward(input) values[noise_amount+1] = output end local filename = string.format("digit%d_variable_noise.png", digit) plot_graph(filename, values:t()) end network = construct_neural_network(INPUT_NEURONS, HIDDEN_NEURONS, OUTPUT_NEURONS) print(network) training_data = prepare_training_data() train_neural_network(network, training_data, LEARNING_RATE, MAX_ITERATION) for digit = 0, 9 do validate_neural_network_variable_noise(network, digit) end for noise = 0, 60, 15 do validate_neural_network(network, 1, noise) validate_neural_network(network, 3, noise) validate_neural_network(network, 8, noise) end
12. Kde je tedy problém?
Výsledky uvedené v kapitole 9 a 10 zdánlivě naznačují, že je naše neuronová síť velmi úspěšná v rozpoznávání obrázků číslic. Ve skutečnosti je však nutné přiznat, že to vlastně vůbec není pravda, a to minimálně ze dvou důvodů:
- Síť dokáže rozpoznat pouze jeden font, což obecně bude vadit, například ve chvíli, kdy namísto námi připravených trénovacích dat použijeme například ručně psané číslice z již zmíněné databáze MNIST. A raději ji vůbec nepouštějte na obrázky získané ze systémů CAPTCHA :-)
- Síť je možné velmi snadno zmást i při použití stále stejného fontu. Postačuje pouze obraz číslice posunout o jeden jediný pixel (jakýmkoli směrem)!
13. Verifikace sítě s posunutými obrázky
Ukažme si, zda platí druhé tvrzení. Nepatrně upravíme funkci pro generování trénovacích a/nebo verifikačních dat tak, aby bylo možné obrázek vertikálně posunout, a to jak nahoru, tak i dolů o zadaný offset (samozřejmě si můžete provést úpravu i pro posun doprava a doleva):
function generate_image_data(digit, noise_amount, offset_y) local max_index = 8*8 codes = digits[digit+1] local index = 1 - 8*offset_y local result = torch.zeros(max_index) for _, code in ipairs(codes) do for i = 1,8 do local bit = code % 2 local value = 192*bit + math.random(0,noise_amount) if index >= 1 and index <= max_index then result[index] = value end index = index + 1 code = (code - bit)/2 end end return result end
14. Ukázka výsledků odhadu sítě pro posunuté obrázky
Výsledky si opět zobrazíme formou 2D grafu. Nejdříve pro odhady sítě pro obrázky s číslicí 1, které jsou postupně stále více zašuměny. Vidíme, že výsledky jsou stále dobré, v souladu s očekáváním:
Obrázek 34: Vstupem jsou obrázky číslice 1. Šum postupně roste od 0 do 64.
Posun obrazu číslice o jeden řádek ovšem síť dokonale zmate a výsledky přestanou být použitelné:
Obrázek 35: Vstupem jsou obrázky číslice 1 posunuté o jeden obrazový řádek. Šum postupně roste od 0 do 64.
Totéž platí i pro posun o řádek, ovšem druhým směrem:
Obrázek 36: Vstupem jsou obrázky číslice 1 posunuté o jeden obrazový řádek. Šum postupně roste od 0 do 64.
Tentýž odhad můžeme provést pro obrázky s číslicí 3:
Obrázek 37: Vstupem jsou obrázky číslice 3. Šum postupně roste od 0 do 64.
Obrázek 38: Vstupem jsou obrázky číslice 3 posunuté o jeden obrazový řádek. Šum postupně roste od 0 do 64.
Obrázek 39: Vstupem jsou obrázky číslice 3 posunuté o jeden obrazový řádek. Šum postupně roste od 0 do 64.
Poslední série odhadů, tentokrát při číslici osm:
Obrázek 40: Vstupem jsou obrázky číslice 8. Šum postupně roste od 0 do 64.
Obrázek 41: Vstupem jsou obrázky číslice 8 posunuté o jeden obrazový řádek. Šum postupně roste od 0 do 64.
Obrázek 42: Vstupem jsou obrázky číslice 8 posunuté o jeden obrazový řádek. Šum postupně roste od 0 do 64.
15. Zdrojový kód druhého příkladu
Opět se podívejme na úplný zdrojový kód dnešního druhého a současně i posledního demonstračního příkladu, v němž se (neúspěšně) snažíme rozpoznat číslice, které jsou v obrázku posunuty o jeden řádek nahoru a dolů:
require("nn") require("gnuplot") -- parametry neuronove site INPUT_NEURONS = 64 HIDDEN_NEURONS = 100 OUTPUT_NEURONS = 10 -- parametry pro uceni neuronove site MAX_ITERATION = 200 LEARNING_RATE = 0.01 NOISE = {0, 8, 16}--, 32} REPEAT_COUNT = 5 DIGITS = 10 digits = { {0x00, 0x3C, 0x66, 0x76, 0x6E, 0x66, 0x3C, 0x00 }, {0x00, 0x18, 0x1C, 0x18, 0x18, 0x18, 0x7E, 0x00 }, {0x00, 0x3C, 0x66, 0x30, 0x18, 0x0C, 0x7E, 0x00 }, {0x00, 0x7E, 0x30, 0x18, 0x30, 0x66, 0x3C, 0x00 }, {0x00, 0x30, 0x38, 0x3C, 0x36, 0x7E, 0x30, 0x00 }, {0x00, 0x7E, 0x06, 0x3E, 0x60, 0x66, 0x3C, 0x00 }, {0x00, 0x3C, 0x06, 0x3E, 0x66, 0x66, 0x3C, 0x00 }, {0x00, 0x7E, 0x60, 0x30, 0x18, 0x0C, 0x0C, 0x00 }, {0x00, 0x3C, 0x66, 0x3C, 0x66, 0x66, 0x3C, 0x00 }, {0x00, 0x3C, 0x66, 0x7C, 0x60, 0x30, 0x1C, 0x00 }, } function generate_image_data(digit, noise_amount, offset_y) local max_index = 8*8 codes = digits[digit+1] local index = 1 - 8*offset_y local result = torch.zeros(max_index) for _, code in ipairs(codes) do for i = 1,8 do local bit = code % 2 local value = 192*bit + math.random(0,noise_amount) if index >= 1 and index <= max_index then result[index] = value end index = index + 1 code = (code - bit)/2 end end return result end function generate_expected_output(digit) local result = torch.zeros(DIGITS) result[digit+1] = 1 return result end function prepare_training_data() local training_data_size = #NOISE * REPEAT_COUNT * DIGITS local training_data = {} function training_data:size() return training_data_size end local index = 1 for _, noise_amount in ipairs(NOISE) do for digit = 0, 9 do for i = 1, REPEAT_COUNT do local input = generate_image_data(digit, noise_amount, 0) local output = generate_expected_output(digit) training_data[index] = {input, output} index = index + 1 end end end return training_data end function construct_neural_network(input_neurons, hidden_neurons, output_neurons) local network = nn.Sequential() network:add(nn.Linear(input_neurons, hidden_neurons)) network:add(nn.Tanh()) network:add(nn.Linear(hidden_neurons, output_neurons)) -- pridana nelinearni funkce network:add(nn.Tanh()) return network end function train_neural_network(network, training_data, learning_rate, max_iteration) local criterion = nn.MSECriterion() local trainer = nn.StochasticGradient(network, criterion) trainer.learningRate = learning_rate trainer.maxIteration = max_iteration trainer:train(training_data) end function plot_graph(filename, values) gnuplot.pngfigure(filename) gnuplot.imagesc(values, 'color') gnuplot.plotflush() gnuplot.close() end function validate_neural_network(network, digit, offset) local values = torch.Tensor(64, DIGITS) for noise_amount = 0, 63 do local input = generate_image_data(digit, noise_amount, offset) local output = network:forward(input) values[noise_amount+1] = output end local filename = string.format("digit%d_offset%d.png", digit, offset) plot_graph(filename, values:t()) end network = construct_neural_network(INPUT_NEURONS, HIDDEN_NEURONS, OUTPUT_NEURONS) print(network) training_data = prepare_training_data() train_neural_network(network, training_data, LEARNING_RATE, MAX_ITERATION) for offset = -1, 1 do validate_neural_network(network, 1, offset) validate_neural_network(network, 3, offset) validate_neural_network(network, 8, offset) end
16. Vylepšení architektury neuronových sítí aneb konvoluční sítě
Jak je tedy možné zlepšit odhad sítě i v případě, že očekáváme, že obrázky budou posunuty, nepatrně otočeny, zkoseny atd.? Máme k dispozici více řešení. Buď udělat síť mnohem víc robustní, což znamená výrazně zvětšit počet skrytých vrstev, zvětšit počet neuronů v těchto vrstvách a o několik řádů zvětšit i množství trénovacích dat (různé formy offsetu, posun jen některých pixelů atd.). To je sice skutečně možné zařídit (ostatně zaplatíme za to „jen“ strojovým časem), ovšem stále zde narážíme na principiální omezení klasických vrstvených neuronových sítí – jednotlivé neurony se učí izolovaně od ostatních neuronů, zatímco na vstupu máme „plovoucí“ obrázek. Bylo by tedy výhodnější se zaměřit na vylepšení samotné architektury neuronové sítě specializované právě na to, že na vstupu bude mít bitmapy a tudíž by sousední neurony měly nějakým způsobem sdílet své váhy na vstupech. Taková architektura již ve skutečnosti byla dávno vymyšlena a jmenuje se konvoluční neuronová sít.
Ovšem stále musíme mít na paměti, že i konvoluční neuronové sítě jsou založené na klasických dopředných sítích, které navíc bývají tzv. hluboké.
17. Vrstvy v konvolučních sítích
V konvolučních sítích se používají vrstvy se speciálním významem i chováním. Jedná se především o takzvané konvoluční vrstvy, které jsou napojeny přímo na vstupní vrstvu popř. na subsamplingové vrstvy. Konvoluční vrstvy se skládají z obecně libovolného množství příznakových map, podle toho, jaké objekty nebo vlastnosti vlastně v obrázku rozpoznáváme. Zpracovávaná bitmapa se zde rozděluje na podoblasti, které se vzájemně překrývají. Neurony přitom mohou sdílet své váhy přiřazené vstupům. Jak přesně to funguje si řekneme příště. Mezi jednotlivé konvoluční vrstvy se vkládají subsamplingové vrstvy, které jsou z výpočetního hlediska jednodušší, protože neurony zde obsahují jen dvě váhy (součet vstupů+práh). Tyto vrstvy získaly svoje jméno podle toho, že umožňují provádět podvzorkování založené většinou na velmi jednoduchých funkcích aplikovaných na okolí každého pixelu (maximální hodnota, střední hodnota…).
Typicky se vrstvy střídají takto:
- Vstupní vrstva
- Konvoluční vrstva #1
- Subsamplingová vrstva #1
- Konvoluční vrstva #2
- Subsamplingová vrstva #2
- …
- …
- Klasická skrytá vrstva
- Výstupní vrstva
Existují ovšem i další možnosti, opět se o nich zmíníme příště.
18. Repositář s demonstračními příklady
Všechny demonstrační příklady, které jsme si popsali v předchozích kapitolách najdete v GIT repositáři dostupném na adrese https://github.com/tisnik/torch-examples.git. Následují odkazy na zdrojové kódy jednotlivých příkladů:
Příklad | Adresa |
---|---|
make_training_images.lua | https://github.com/tisnik/torch-examples/blob/master/nn/bitmapnn/make_training_images.lua |
01_noisy_images.lua | https://github.com/tisnik/torch-examples/blob/master/nn/bitmapnn/01_noisy_images.lua |
02_offset_images.lua | https://github.com/tisnik/torch-examples/blob/master/nn/bitmapnn/02_offset_images.lua |
Poznámka: první skript je možné spouštět přímo z interpretru jazyka Lua, není tedy nutné používat framework Torch.
19. Odkazy na Internetu
- THE MNIST DATABASE of handwritten digits
http://yann.lecun.com/exdb/mnist/ - MNIST database (Wikipedia)
https://en.wikipedia.org/wiki/MNIST_database - MNIST For ML Beginners
https://www.tensorflow.org/get_started/mnist/beginners - Stránka projektu Torch
http://torch.ch/ - Torch: Serialization
https://github.com/torch/torch7/blob/master/doc/serialization.md - Torch: modul image
https://github.com/torch/image/blob/master/README.md - Data pro neuronové sítě
http://archive.ics.uci.edu/ml/index.php - LED Display Domain Data Set
http://archive.ics.uci.edu/ml/datasets/LED+Display+Domain - Torch na GitHubu (několik repositářů)
https://github.com/torch - Torch (machine learning), Wikipedia
https://en.wikipedia.org/wiki/Torch_%28machine_learning%29 - Torch Package Reference Manual
https://github.com/torch/torch7/blob/master/README.md - Torch Cheatsheet
https://github.com/torch/torch7/wiki/Cheatsheet - Neural network containres (Torch)
https://github.com/torch/nn/blob/master/doc/containers.md - Simple layers
https://github.com/torch/nn/blob/master/doc/simple.md#nn.Linear - Transfer Function Layers
https://github.com/torch/nn/blob/master/doc/transfer.md#nn.transfer.dok - Feedforward neural network
https://en.wikipedia.org/wiki/Feedforward_neural_network - Biologické algoritmy (4) – Neuronové sítě
https://www.root.cz/clanky/biologicke-algoritmy-4-neuronove-site/ - Biologické algoritmy (5) – Neuronové sítě
https://www.root.cz/clanky/biologicke-algoritmy-5-neuronove-site/ - Umělá neuronová síť (Wikipedia)
https://cs.wikipedia.org/wiki/Um%C4%9Bl%C3%A1_neuronov%C3%A1_s%C3%AD%C5%A5 - Učení s učitelem (Wikipedia)
https://cs.wikipedia.org/wiki/U%C4%8Den%C3%AD_s_u%C4%8Ditelem - Plotting with Torch7
http://www.lighting-torch.com/2015/08/24/plotting-with-torch7/ - Plotting Package Manual with Gnuplot
https://github.com/torch/gnuplot/blob/master/README.md - An Introduction to Tensors
https://math.stackexchange.com/questions/10282/an-introduction-to-tensors - Gaussian filter
https://en.wikipedia.org/wiki/Gaussian_filter - Gaussian function
https://en.wikipedia.org/wiki/Gaussian_function - Laplacian/Laplacian of Gaussian
http://homepages.inf.ed.ac.uk/rbf/HIPR2/log.htm - Odstranění šumu
https://cs.wikipedia.org/wiki/Odstran%C4%9Bn%C3%AD_%C5%A1umu - Binary image
https://en.wikipedia.org/wiki/Binary_image - Erosion (morphology)
https://en.wikipedia.org/wiki/Erosion_%28morphology%29 - Dilation (morphology)
https://en.wikipedia.org/wiki/Dilation_%28morphology%29 - Mathematical morphology
https://en.wikipedia.org/wiki/Mathematical_morphology - Cvičení 10 – Morfologické operace
http://midas.uamt.feec.vutbr.cz/ZVS/Exercise10/content_cz.php - Differences between a matrix and a tensor
https://math.stackexchange.com/questions/412423/differences-between-a-matrix-and-a-tensor - Qualitatively, what is the difference between a matrix and a tensor?
https://math.stackexchange.com/questions/1444412/qualitatively-what-is-the-difference-between-a-matrix-and-a-tensor? - BLAS (Basic Linear Algebra Subprograms)
http://www.netlib.org/blas/ - Basic Linear Algebra Subprograms (Wikipedia)
https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms - Comparison of deep learning software
https://en.wikipedia.org/wiki/Comparison_of_deep_learning_software - TensorFlow
https://www.tensorflow.org/ - Caffe2 (A New Lightweight, Modular, and Scalable Deep Learning Framework)
https://caffe2.ai/ - PyTorch
http://pytorch.org/ - Seriál o programovacím jazyku Lua
http://www.root.cz/serialy/programovaci-jazyk-lua/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (2)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-2/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (3)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-3/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (4)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-4/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (5 – tabulky a pole)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-5-tabulky-a-pole/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (6 – překlad programových smyček do mezijazyka LuaJITu)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-6-preklad-programovych-smycek-do-mezijazyka-luajitu/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (7 – dokončení popisu mezijazyka LuaJITu)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-7-dokonceni-popisu-mezijazyka-luajitu/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (8 – základní vlastnosti trasovacího JITu)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-8-zakladni-vlastnosti-trasovaciho-jitu/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (9 – další vlastnosti trasovacího JITu)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-9-dalsi-vlastnosti-trasovaciho-jitu/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (10 – JIT překlad do nativního kódu)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-10-jit-preklad-do-nativniho-kodu/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (11 – JIT překlad do nativního kódu procesorů s architekturami x86 a ARM)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-11-jit-preklad-do-nativniho-kodu-procesoru-s-architekturami-x86-a-arm/ - LuaJIT – Just in Time překladač pro programovací jazyk Lua (12 – překlad operací s reálnými čísly)
http://www.root.cz/clanky/luajit-just-in-time-prekladac-pro-programovaci-jazyk-lua-12-preklad-operaci-s-realnymi-cisly/ - Lua Profiler (GitHub)
https://github.com/luaforge/luaprofiler - Lua Profiler (LuaForge)
http://luaforge.net/projects/luaprofiler/ - ctrace
http://webserver2.tecgraf.puc-rio.br/~lhf/ftp/lua/ - The Lua VM, on the Web
https://kripken.github.io/lua.vm.js/lua.vm.js.html - Lua.vm.js REPL
https://kripken.github.io/lua.vm.js/repl.html - lua2js
https://www.npmjs.com/package/lua2js - lua2js na GitHubu
https://github.com/basicer/lua2js-dist - Lua (programming language)
http://en.wikipedia.org/wiki/Lua_(programming_language) - LuaJIT 2.0 SSA IRhttp://wiki.luajit.org/SSA-IR-2.0
- The LuaJIT Project
http://luajit.org/index.html - LuaJIT FAQ
http://luajit.org/faq.html - LuaJIT Performance Comparison
http://luajit.org/performance.html - LuaJIT 2.0 intellectual property disclosure and research opportunities
http://article.gmane.org/gmane.comp.lang.lua.general/58908 - LuaJIT Wiki
http://wiki.luajit.org/Home - LuaJIT 2.0 Bytecode Instructions
http://wiki.luajit.org/Bytecode-2.0 - Programming in Lua (first edition)
http://www.lua.org/pil/contents.html - Lua 5.2 sources
http://www.lua.org/source/5.2/ - REPL
https://en.wikipedia.org/wiki/Read%E2%80%93eval%E2%80%93print_loop - The LLVM Compiler Infrastructure
http://llvm.org/ProjectsWithLLVM/ - clang: a C language family frontend for LLVM
http://clang.llvm.org/ - LLVM Backend („Fastcomp“)
http://kripken.github.io/emscripten-site/docs/building_from_source/LLVM-Backend.html#llvm-backend - Lambda the Ultimate: Coroutines in Lua,
http://lambda-the-ultimate.org/node/438 - Coroutines Tutorial,
http://lua-users.org/wiki/CoroutinesTutorial - Lua Coroutines Versus Python Generators,
http://lua-users.org/wiki/LuaCoroutinesVersusPythonGenerators