© Enkel/Shutterstock.com
Im Überblick

Tribuo Concepts


ÜBERBLICK: BASIC TYPES

Dukes_Infografik1.jpg_fmt1.png

Tribuo ist eine statisch und stark typisierte, in Java geschrieben Machine-Learning-Bibliothek. Sie arbeitet mit Objekten und nicht mit mehrdimensionalen Arrays, d. h., dass die Modelle Example-Objekte akzeptieren und typisierte Prediction-Outputs erzeugen. Tribuo verwendet auch das Typsystem, um zu erzwingen, dass die Trainings-Data-Sets mit den erwarteten Modellen übereinstimmen (zum Beispiel, dass Classification Trainer auf einem Classification Data Set verwendet wird).

Data Sets werden über das DataSource Interface in Tribuo geladen, hier laden wir das beliebte MNIST Data Set für handgeschriebene Zahlen:

var labelFactory = new LabelFactory(); var trainDataSource = new IDXDataSource<>(Paths.get("./train-images-idx3-ubyte.gz"),Paths.get("./train-labels-idx1-ubyte.gz"),labelFactory);

DataSource ist ein (möglicherweise lazy) Iterable über die Beispiele in der angegebenen Quelle. Data Sources führen die Featurisierung der Daten in ein für das Training geeignetes Beispiel durch. Featurisierung bedeutet die Umwandlung der Daten in ein spärliches Array von Werten (z. B. Pixelwerte für ein Bild, Wortanzahl für Text usw.).

Anschließend laden wir die Data Source in ein Data Set. Data Sets tracken die Domänen der Features und der Outputs, sie kennen die Bereiche der Features und Outputs der Beispiele, die sie enthalten. Nachdem die Daten in ein Data Set geladen wurden, können sie auf verschiedene Weise transformiert werden, z. B. durch Neuskalierung der Features zwischen null und eins oder durch Binning der Featurewerte usw. Diese Transformationen werden im Data-Set-Objekt nachverfolgt, sodass wir sie später wiederherstellen können.

var trainDataset = new MutableDataset<>(trainDataSource);

Data Sources und Data Sets sind Sammlungen von Example<T>, wobei der Outputtask im Typ kodiert ist (z. B. verwenden Classifications Label, Regressions verwenden Regressor). Jedes Example ist eine implizit spärliche (sparse) Liste von Featureobjekten, wobei ein Feature ein Name und ein Wert ist. Das macht Tribuo besonders geeignet für Natural-Language-Processing-Tasks, da es sich bei Dokumenten um spärliche Sammlungen von Wörtern handelt. Featurenamen anstelle von Featureindizes machen es einfach, herauszufinden, ob das Modell ein bestimmtes Wort schon einmal gesehen hat – es ist in der Domäne gespeichert.

Bevor wir ein Modell trainieren können, müssen wir zunächst einen Trainer erstellen und dessen Hyperparameter festlegen. Wir erstellen eine einfache logistische Regression:

var trainer = new LinearSGDTrainer(new LogMulticlass(), new AdaGrad(0.5), 2, 1);

Dies ist ein linearer Modelltrainer, der einen logarithmischen Verlust verwendet (daher ist es eine logistische Regression), trainiert mit dem AdaGrad-Gradientenabstiegsoptimierer für drei Epochen (d. h. drei Durchläufe durch die Trainingsdaten). Jetzt sind wir bereit, ein Modell zu trainieren:

var model = trainer.train(trainDataset);

Dieser Aufruf protokolliert Informationen über das Training in den Logger und erstellt ein Modell. Die Domänen aus dem Datensatz werden in das Modell kopiert, zusammen mit den Beschreibungen aller Transformationen und der ursprünglichen Datenquelle. Unser Modell ist nun einsatzbereit, entweder durch Einspeisung neuer Beispiele oder durch Evaluierung seiner Leistung anhand eines Test-Sets. Wir werden Letzteres tun:

var testDataSource = new IDXDataSource<>(Paths.get("./t10k-images-idx3-ubyte.gz"),Paths.get("./t10k-labels-idx1-ubyte.gz"),labelFactory); var evaluation = new LabelEvaluator().evaluate(model,testDataSource);

Für jede der verschiedenen Prediction-Tasks, die Tribuo unterstützt, gibt es Auswertungen, die die relevanten Metriken anzeigen. Wir können eine spezifische Metrik wie die Genauigkeit anzeigen oder einen formatierten String erzeugen, der die häufig verwendeten Metriken darstellt.

double accuracy = evaluation.accuracy(); String formattedOutput = evaluation.toString();

Der formattedOutput-String sieht folgendermaßen aus:

Class n tp fn fp recall prec f1 0 980 904 76 21 0.922 0.977 0.949 1 1,135 1,072 63 18 0.944 0.983 0.964 2 1,032 856 176 56 0.829 0.939 0.881 3 1,010 844 166 84 0.836 0.909 0.871 4 982 888 94 72 0.904 0.925 0.915 5 892 751 141 143 0.842 0.840 0.841 6 958 938 20 139 0.979 0.871 0.922 7 1,028 963 65 133 0.937 0.879 0.907 8 974 892 82 363 0.916 0.711 0.800 9 1,009 801 208 62 0.794 0.928 0.856 Total 10,000 8,909 1,091 1,091 Accuracy 0.891 Micro Average 0.891 0.891 0.891 Macro Average 0.890 0.896 0.890 Balanced Error Rate 0.110

Output Types

Tribuo kodiert den Prediction-Task im Typsystem als Unterklassen von Output. Jede Output-Unterklasse kommt mit einer Familie von unterstützenden Kernklassen für diesen Prediction-Task. Tribuo definiert eine Reihe von Schnittstellen für seine Output-Klassen:

  • Output repräsentiert den vorhergesagten Wert selbst, zusammen mit einer eventuellen Bewertung.

  • OutputInfo repräsentiert die Domäne der Outputs und verfolgt relevante Statistiken.

  • Evaluator und Evaluation repräsentieren die Auswertungsfunktion und die Ergebnisse.

  • OutputFactory erzeugt Outputs, Outputinfos und Evaluatoren des entsprechenden Typs.

OutputFactory hat eine zusätzliche Verwendung: Es enthält den „Unknown“-Sentinel-Output, der verwendet wird, wenn der Ground-Truth-Output nicht bekannt ist (z. B. zur Inferenz-/Deployment-Zeit).

OutputInfo-Implementierungen gibt es in zwei Formen: MutableOutputInfo und ImmutableOutputInfo. Die veränderliche Form wird innerhalb von Data Sets verwendet, die ihre Domänen tracken, und die unveränderliche Form wird innerhalb von Modellen und unveränderlichen Data Sets verwendet. Die unveränderliche Form enthä...

Neugierig geworden? Wir haben diese Angebote für dich:

Angebote für Gewinner-Teams

Wir bieten Lizenz-Lösungen für Teams jeder Größe: Finden Sie heraus, welche Lösung am besten zu Ihnen passt.

Das Library-Modell:
IP-Zugang

Das Company-Modell:
Domain-Zugang