© Sashkin/Shutterstock.com
In einfachen Schritten mit Deep Learning durchstarten

Schnelleinstieg in Deeplearning4j


Dieser Artikel zeigt, wie man in kürzester Zeit den Einstieg in Deeplearning4j (DL4J) schafft. Anhand eines Beispiels, in dem vorhergesagt werden soll, ob ein Kunde seine Bank verlassen wird, wird jeder Schritt eines typischen Arbeitsablaufs betrachtet.

Deep Learning, d. h. die Verwendung von tiefen, mehrschichtigen neuronalen Netzen, ist der große Treiber des aktuellen Booms rund um Machine Learning. Von großen Sprüngen in der Qualität automatischer Übersetzungen über autonomes Fahren bis hin zum Schlagen von Großmeistern in dem Spiel Go macht diese Technik vielfach Schlagzeilen.

Deeplearning4j, auch DL4J genannt, ist eine Java-Bibliothek für Deep Learning. Zu ihr gehört zudem eine ganze Familie von weiteren Bibliotheken, die die Verwendung von Deep-Learning-Modellen mit Java vereinfachen. Als eine Alternative zu den vielen Python-basierten Frameworks bietet DL4J einen Weg, wie Deep Learning auch in Enterprise-Umgebungen einfach in den Produktivbetrieb gebracht werden kann. Der vollständige Code des Artikels, inklusive Trainingsdaten, befindet sich auf GitHub [1].

Einbindung ins Projekt

DL4J kann wie viele andere Java-Bibliotheken einfach als eine weitere Abhängigkeit in das Build-Tool der Wahl aufgenommen werden. In diesem Artikel werden die dafür notwendigen Angaben im Maven-Format gemacht, also so, wie sie in der pom.xml-Datei stehen würden. Natürlich kann man auch ein anderes Build-Tool wie Gradle oder SBT verwenden.

Eine Verwendung ohne Build-Tools ist für DL4J jedoch nicht vorgesehen, da es selbst auch eine Vielzahl von direkten und transitiven Abhängigkeiten hat. Deswegen gibt es auch keine einzelne .jar-Datei, die man manuell als Abhängigkeit in seiner IDE angeben könnte.

DL4J und die dazugehörigen Bibliotheken sind modular aufgebaut, sodass man seine Abhängigkeiten den Bedürfnissen des Projekts anpassen kann. Gerade für Einsteiger kann das jedoch die Verwendung verkomplizieren, da es nicht zwangsläufig offensichtlich ist, welches Untermodul benötigt wird, um eine bestimmte Klasse verfügbar zu machen.

Die verwendeten Versionen aller DL4J-Module sollten immer gleich sein. Um das zu vereinfachen, definieren wir eine Property, die wir im Folgenden immer verwenden werden, um die Versionsangabe zu machen. DL4J ist im Moment kurz vor seinem 1.0-Release und in diesem Artikel verwenden wir Version 1.0.0-beta2, die erst vor Kurzem erschienen ist.

<properties> <dl4j.version>1.0.0-beta2</dl4j.version> </properties>

Für Einsteiger ist es ratsam, mit dem Modul deeplearning4j-core zu beginnen. Das zieht viele weitere Module transitiv mit sich und erlaubt somit gleich die Nutzung einer Vielzahl von Features, ohne dass man sich auf die Suche nach der richtigen Abhängigkeit machen muss. Der Nachteil ist, dass man beim Bündeln aller Abhängigkeiten in ein Uber-JAR eine große Datei bekommt.

<dependency> <groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-core</artifactId> <version>${dl4j.version}</version> </dependency>

DL4J unterstützt mehrere Backends, die die Verwendung von CPU oder GPU ermöglichen. Die Wahl des Backend geschieht dabei im einfachsten Fall über die Angabe einer Abhängigkeit. Um das CPU Backend zu verwenden, wird nd4j-native-platform benötigt. Für das GPU Backend wird hingegen nd4j-cuda-X.Y-platform verwendet, wobei X.Y durch die installierte CUDA-Version ausgetauscht werden sollte. Unterstützt werden im Augenblick CUDA 8.0, 9.0 und 9.2.

<dependency> <groupId>org.nd4j</groupId> <artifactId>nd4j-native-platform</artifactId> <version>${dl4j.version}</version> </dependency>

Beide Backends setzen auf die Verwendung von nativen Binarys, weswegen die Plattformmodule auch die Binaries für alle unterstützten Plattformen einbinden. Das ermöglicht die Verteilung auf mehrere unterschiedliche Plattformen, ohne dafür jeweils eine einzelne spezialisierte JAR-Datei erstellen zu müssen. Die aktuell unterstützten Plattformen für die jeweiligen CPUs sind:

  • Linux (PPC64LE, x86_64)

  • Windows (x86_64)

  • macOS (x86_64)

  • Android (ARM, ARM64, x86, x86_64)

Für CUDA-fähige GPUs:

  • Linux (PPC64LE, x86_64)

  • macOS (x86_64)

  • Windows (x86_64)

Aufgrund dessen, dass einige DL4J-eigene Abhängigkeiten noch nicht voll mit neueren Java-Versionen kompatibel sind, ist die hier verwendete Version auch nur mit maximal Java 8 verwendbar. Für das kommende Release der Version 1.0 ist jedoch eine volle Kompatibilität zu Java 11 geplant.

Zuletzt fügen wir noch einen Logger zu unseren Abhängigkeiten hinzu. DL4J benötigt einen mit dem SLF4J-API kompatiblen Logger, um seine Informationen mit uns zu teilen. Für unser Beispiel verwenden wir hier Logback Classic.

<dependency> <groupId>ch.qos.logback</groupId> <artifactId>logback-classic</artifactId> <version>1.2.3</version> </dependency>

ND4J: Fundament für Deeplearning4j

Wie schon an der Angabe des Backend zu sehen ist, bildet ND4J das Fundament, auf dem DL4J aufbaut. ND4J ist eine Bibliothek für schnelle Tensormathematik mit Java. Um die Hardware maximal auszunutzen, werden dabei praktisch alle Berechnungen außerhalb der JVM durchgeführt. Auf diese Weise können sowohl CPU-Features wie z. B. AVX-Vektorinstruktionen als auch GPUs verwendet werden.

Wenn eine GPU verwendet wird, sollte man jedoch berücksichtigen, dass gerade für Deep Learning oftmals eine recht potente GPU notwendig ist, um einen Geschwindigkeitsvorteil gegenüber einer CPU zu erlangen. Das gilt insbesondere für Notebook-GPUs, die vor der aktuellen GeForce-1000er-Serie erschienen sind, und selbst auf dem Desktop sollte man zumindest mit einer GeForce GTX 960 mit 4 GB RAM aufwarten. Der Grund für diese Empfehlung liegt darin, dass GPUs insbesondere bei Berechnungen mit großen Datenmengen glänzen – diese großen Datenmengen benötigen aber auch entsprechend viel RAM und dieser ist erst bei den stärkeren Modellen in ausreichender Menge vorhanden.

Laden der Daten

Machine Learning jeder Art beginnt zunächst einmal immer damit, dass Daten gesammelt und geladen werden müssen. Auch bei Deep Learning ist das nicht anders. Dieser Artikel konzentriert sich auf tabellarische Daten, die als CSV vorliegen. Das Vorgehen für andere Dateiformate oder -arten wie Bilder ist jedoch ähnlich.

Grundsätzlich gilt: Wenn man schnell zu guten Ergebnissen kommen möchte, sollte man seine Daten und das zu lösende Problem gut verstehen. Etwas Expertenwissen um die Daten und das generelle Problemfeld sowie eine entsprechende Vorbereitung der Daten können hierbei die Modellkomplexität und Trainingszeit in vielen Fällen bedeutend reduzieren.

Ein weiterer zu beachtender Punkt ist, dass man seine Daten für das Training eines Modells in mindestens zwei Teile aufteilen muss. Der Großteil der Daten, üblicherweise etwa 80 Prozent, wird zum Training verwendet und somit als Trainingsmenge bezeichnet. Die restlichen Daten, üblicherweise etwa 20 Prozent, werden zur Untersuchung der Qualität des Modells verwendet und als Testmenge bezeichnet. Insbesondere bei weitergehendem Tuning ist es sogar üblich, weitere 10 Prozent der Trainingsdaten für eine Validationsmenge zu reservieren, mit der man prüft, ob das Modell nicht zu sehr auf die Testmenge zugeschnitten wurde.

Bei der Wahl der Daten für die Testmenge sollte beachtet werden, dass sie eine repräsentative Teilmenge aller Daten ausmachen. Das ist notwendig, um die Aussagekraft des Modells ordnungsgemäß überprüfen zu können.

Der in diesem Artikel verwendete Datensatz [2] stammt von Kaggle, einer Plattform für Data-Science-und Machine-Learning-Wettbewerbe. Er besteht aus tabellarischen Daten und beinhaltet nicht nur rein numerische, sondern auch kategorische Daten. Er wurde bereits etwas vorverarbeitet und in Trainingsmenge und Testmenge aufgeteilt (Abb. 1).

dubs_dl4j_1.tif_fmt1.jpgAbb. 1: Kundendaten einer Bank

Der Datensatz besteht aus Kundendaten einer Bank. Jede Zeile stellt dabei einen Kunden dar und beinhaltet in der Spalte „Exited“ auch die Information, ob der Kunde die Bank verlassen hat. Die Problemstellung, die wir in diesem Beispiel lösen wollen, ist das Training eines Modells, das anhand dieser Kundendaten vorhersagen kann, ob ein Kunde die Bank verlassen wird. Es ist also ein klassisches Klassifikationsproblem mit zwei Klassen: „Kunde wird bleiben“ und „Wird die Bank verlassen“.

DataVec

Wie alle anderen statistischen Machine-Learning-Verfahren funktioniert auch Deep Learning nur mit numerischen Daten. DataVec ist eine DL4J-Bibliothek, die uns beim Laden, Analysieren und Konvertieren unserer Daten in das notwendige Format unterstützt. Um sie zu verwenden, müssen wir keine weitere Abhängigkeit angeben, da sie bereits vom Modul deeplearning4j-core mitgeladen wird.

Im Beispiel werden wir die drei Kernkonzepte von DataVec antreffen. Diese bestehen aus dem InputSplit, RecordReader und TransformProccess. Man kann sie als jene Schritte verstehen, die von den Daten durchlaufen werden müssen, um von Rohdaten zu tatsächlich verwendbaren Daten angereichert zu werden (Abb. 2).

dubs_dl4j_2.tif_fmt1.jpgAbb. 2: Der Weg von Rohdaten zu tatsächlich verwendbaren Daten

Wir beginnen damit, dass ein FileSplit erzeugt wird. Diesem geben wir den Ordner an, in dem sich unsere Trainingsmenge befindet. Die Aufgabe des InputSplits wird sein, dem RecordReader jeweils einen einzelnen Input anzubieten. Deswegen gibt es außer dem FileSplit auch noch eine Reihe weiterer Implementierungen, die z. B. Daten aus einem InputStream oder auch einer Collection bereitstellen können.

Dadurch, dass wir auch ein optionales Random-Objekt mitgeben, wird der FileSplit die Dateien in einer zufälligen Reihenfolge durchgehen. Das wird im späteren Training noch wichtig.

Random random = new Random(); random.setSeed(0xC0FFEE); FileSplit inputSplit = new FileSplit(new File("X:/Churn_Modelling/Train/"), random);

Unsere Daten liegen im CSV-Format vor, deswegen verwenden wir einen CSVRecordReader. Der wird mit dem zuvor erstellten InputSplit initialisiert. Der RecordReader wird im weiteren Verlauf den Input, den er vom InputSplit bekommt, nehmen und in ein oder mehrere Beispiele aufteilen. Diese Beispiele sind in Form von Records vom RecordReader abrufbar.

CSVRecordReader recordReader = new CSVRecordReader(); recordReader.initialize(inputSplit);

Auch für den RecordReader gibt es eine Reihe von anderen Implementierungen, die etwa Excel-Dateien, Bilder, Videos oder auch (via JDBC verbundene) Datenbanken lesen können.

Records sind im Prinzip nichts weiter als eine Liste von Werten. Und gerade im Fall vom CSVRecordReader sind diese Werte zunächst einmal alle String...

Neugierig geworden? Wir haben diese Angebote für dich:

Angebote für Gewinner-Teams

Wir bieten Lizenz-Lösungen für Teams jeder Größe: Finden Sie heraus, welche Lösung am besten zu Ihnen passt.

Das Library-Modell:
IP-Zugang

Das Company-Modell:
Domain-Zugang