6 See one, do one, teach one! Machine Learning in R
Nachdem wir uns nun bereits in deskriptive Statistiken und statistische Tests eingearbeitet haben, können wir einen Schritt weiter gehen und kommen zu dem Thema, das aktuell mehr denn je in aller Munde ist - Maschinelles Lernen bzw. Machine Learning. In seinen einfacheren Ausprägungen ist maschinelles Lernen nichts anderes als die Erarbeitung eines statistischen Modells, welches dann wiederum auf neue Daten angewandt werden kann, um bspw. Vorhersagen zu treffen.
Heute werden die Begriffe “Künstliche Intelligenz”, “Maschinelles Lernen” und “Neuronale Netzwerke” teils synonym verwandt. Korrekterweise kann man jedoch sagen, dass die neuronalen Netze ledigliche ein Teilgebiet des maschinellen Lernens sind, das wiederum ein Teil von dem bezeichnet was unter künstlicher Intelligenz verstanden wird. Im folgenden werden wir einige einfachere Algorithmen kennenlernen und in R ausprobiern. Wer danach noch Hunger auf mehr Machine Learning Algorithmen und deren Anwendung in R hat, dem sei die Webseite 101 Machine Learning Algorithms ans Herz gelegt.
6.1 Lernziele
- Verschiedene Machine Learning Algorithmen in R nutzen
- Daten in Training- und Testdaten unterteilen
- Trainierte Modelle auf neue Daten anwenden
- Klassifikationsgüte berechnen
6.2 Das Iris Dataset
In vielen Beipsielen wird online auf das sogenannte Iris Dataset Bezug genommen. Der Datensatz hat deshalb einige Berühmtheit erlangt, der Einfachheit halber verwenden wir ihn deshalb auch für dieses Webinar.
data("iris")
str(iris)
## 'data.frame': 150 obs. of 5 variables:
## $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
## $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
## $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
## $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
## $ Species : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
In dem Datensatz enthalten sind Beobachtungen zu 150 verschiedenen Schwertlilien enthalten, jede Beobachtung enthält Angaben zu Länge und Breite der Kelch- und Kronblätter sowie zur Zugehörigkeit zu einer von drei Schwertlilienarten. Die übliche Aufgabe ist es dann anhand dieses Datensatzes einen Algorithmus zu trainieren, der aus den Angaben zur Länge und Breite der Kelch- und Kronblätter Vorhersagen über die Artzugehörigkeit trifft.
6.3 Explorative Datenvisualisierung
Ein guter Anfang ist meist sich einen visuellen Überblick über die Daten zu verschaffen. Man könnte beispielsweise die Verteilung der Messwerte als gruppierte Punktwolken darstellen.
%>%
iris pivot_longer(-Species, names_to = "variable") %>%
ggplot(aes(x = Species, y = value, color = Species)) +
geom_jitter() +
facet_wrap(vars(variable))
Wie man sieht, müsste es möglich sein anhand der Daten die Artzugehörigkeit abzuschätzen. Rein visuell könnte man vermuten, dass sich dafür insbesondere die Länge der Kronblätter (Petal-Length
) eignen müsste.
6.4 Lineare Regression
Einer der einfachsten Machine Learning Algorithmen, die Lineare Regression eignet sich zwar nicht zur Vorhersage eines kategorialen Variable, wie sie die Artzugehörigkeit ist, sollte aber hier trotzdem nicht unerwähnt bleiben. Üblicherweise versucht eine lineare Regression einen kontinuierlichen Zahlenwert für eine abhängige Variable aus einer oder mehreren unabhängigen Variablen zu berechnen. In unserem Fall könnten wir beispielsweise versuchen die Breite der Kronblätter (Petal.Width
) aus den übrigen Variablen abzuschätzen.
Hierzu nutzen wir die Funktion lm()
. Erster Parameter dieser Funktion ist eine Formel, wie wir sie bereits im Kapitel zu den statistischen Tests benutzt haben (Abschnitt 4.3.1). In Formeln kann der Punkt .
genutzt werden, um alle Variablen (bzw. Spalten) außer der links der Tilde ~
angegebenen zu referenzieren. Der zweite Parameter der Funktion ist das Dataframe, das genutzt werden soll. Da es hier keinen Sinn machen würde die Art einzuschließen, nutzen wir innerhalb der lm()
-Funktion ein select()
, um die Variable Species
auszuschließen.
<- lm(formula = Petal.Width ~ ., data = iris %>% select(-Species))
fit summary(fit)
##
## Call:
## lm(formula = Petal.Width ~ ., data = iris %>% select(-Species))
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.60959 -0.10134 -0.01089 0.09825 0.60685
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -0.24031 0.17837 -1.347 0.18
## Sepal.Length -0.20727 0.04751 -4.363 2.41e-05 ***
## Sepal.Width 0.22283 0.04894 4.553 1.10e-05 ***
## Petal.Length 0.52408 0.02449 21.399 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.192 on 146 degrees of freedom
## Multiple R-squared: 0.9379, Adjusted R-squared: 0.9366
## F-statistic: 734.4 on 3 and 146 DF, p-value: < 2.2e-16
Die Ausgabe ist auf den ersten Blick nicht besonders eingängig, zeigt aber im Wesentlichen schon, dass signifikante Zusammenhänge zwischen Kronblattbreite und allen anderen Variablen existieren.
Natürlich können auch die Ergebnisobjekte von Machine Learning Algorithmen für die weitere Nutzung mit Tidyverse-Paketen aufbereitet werden (siehe Abschnitt 4.4.1).
library(broom)
tidy(fit)
## # A tibble: 4 × 5
## term estimate std.error statistic p.value
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 (Intercept) -0.240 0.178 -1.35 1.80e- 1
## 2 Sepal.Length -0.207 0.0475 -4.36 2.41e- 5
## 3 Sepal.Width 0.223 0.0489 4.55 1.10e- 5
## 4 Petal.Length 0.524 0.0245 21.4 7.33e-47
%>%
fit tidy(conf.int = TRUE) %>%
filter(term != "(Intercept)") %>%
mutate(p.value.chr = format.pval(p.value, digits = 2)) %>%
ggplot(aes(x = estimate, y = fct_reorder(term, estimate))) +
geom_point() +
geom_errorbar(aes(xmin = conf.low, xmax = conf.high), width = 0.15) +
geom_vline(xintercept = 0, linetype = "dashed") +
geom_text(aes(label = paste("p-value =", p.value.chr)), nudge_y = 0.2) +
scale_x_continuous(limits = c(-0.8, 0.8)) +
labs(y = "",
x = "Estimate",
title = "Linear regression coefficients for Petal.Width")
6.5 Logistische Regression
Einer der einfachtsten Machine Learning Algorithmen zur Vorhersage einer binären Entscheidung (wie bspw. benigne vs. maligne) ist die Logistische Regression. Das Vorgehen in R ist hierbei weitestgehend gleich zu dem bei der linearen Regression, nur dass wir in diesem Fall eben als abhängige Variable eine Variable nehmen müssen, die entweder als Character
oder als Factor
vorliegt und nur zwei Ausprägungen hat.
Im Iris-Datensatz könnten wir also beispeilsweise versuchen anhand von Blattlängen und -breiten zwischen den Arten virginica
und versicolor
zu unterscheiden. Dazu nutzen wir die Funktion glm()
und geben dieser als Parameter family = "binomial
mit.
# Erstellen einer Teilmenge des Datensatzes mit nur zwei Spezies
<- iris %>%
iris_binominal filter(Species %in% c("virginica", "versicolor")) %>%
# da die Variable Species weiterhin als Factor mit drei Level
# angelegt wäre, nutzen wir die droplevels()-Funktion, um
# Fehler zu vermeiden.
droplevels()
# family = "binomial" gibt hier an, dass nur eine binäre Entscheidung zu treffen ist
<- glm(Species ~ ., data = iris_binominal, family = "binomial")
fit summary(fit)
##
## Call:
## glm(formula = Species ~ ., family = "binomial", data = iris_binominal)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.01105 -0.00541 -0.00001 0.00677 1.78065
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) -42.638 25.707 -1.659 0.0972 .
## Sepal.Length -2.465 2.394 -1.030 0.3032
## Sepal.Width -6.681 4.480 -1.491 0.1359
## Petal.Length 9.429 4.737 1.991 0.0465 *
## Petal.Width 18.286 9.743 1.877 0.0605 .
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 138.629 on 99 degrees of freedom
## Residual deviance: 11.899 on 95 degrees of freedom
## AIC: 21.899
##
## Number of Fisher Scoring iterations: 10
In diesem konkreten Beispiel erscheint es so, dass lediglich die Kronblattlänge (Petal.Length
) signifikant mit der Artzugehörigkeit zusammenhängt. In einem “echten” Projekt könnte man jetzt versuchen das Modell zu optimieren und nicht signifikante Faktoren entfernen, für unsere Zwecke arbeiten wir aber jetzt mit diesem (sicher nicht idealen) Modell weiter.
Wir können nun das in der Variablen fit
gespeicherte Modell mithilfe der Funktion predict()
nutzen, und uns die Vorhersagewerte des Modells zu ausgeben. Mit einigen praktischen Befehlen können wir die Vorhersagen auch gleich als neue Spalte dem ursprünglichen Dataframe hinzufügen.
# wir nutzen hier die add_column()-Funktion aus dem tidyverse
# und die predict()-Funktion um Werte mithilfe eines Modells
# zu berechnen
<- iris_binominal %>%
iris_binominal_predictions add_column(prediction_value = predict(fit))
# ein Blick in die ersten Zeilen der Daten
head(iris_binominal_predictions)
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species prediction_value
## 1 7.0 3.2 4.7 1.4 versicolor -11.354482
## 2 6.4 3.2 4.5 1.5 versicolor -9.932613
## 3 6.9 3.1 4.9 1.5 versicolor -6.725380
## 4 5.5 2.3 4.0 1.3 versicolor -10.073036
## 5 6.5 2.8 4.6 1.5 versicolor -6.563842
## 6 5.7 2.8 4.5 1.3 versicolor -9.191831
Wie man sieht, gibt das Modell der logistischen Regression uns einen numerieschen Wert (in unserem Fall in der Spalte prediction_value
) zurück, der in gewisser Weise aber die Zugehörigkeit zu der jeweiligen Klasse wiederspiegelt. Diese Daten könnten wir nun für eine einfache Visualisierung der Klassifikationsgenauigkeit nutzen.
%>%
iris_binominal_predictions ggplot(aes(y = prediction_value, x = Species, color = Species)) +
geom_jitter() +
geom_hline(yintercept = 0, linetype = "dashed") +
labs(x = "Spezies",
y = "Ausgabewert der logistischen Regression",
color = "") +
theme(legend.position = "none")
Wie wir aus der Grafik unschwer erkennen, scheinen positive Werte für die Ausgabe der logistischen Regression eher für Schwertlilien der Art virginica
zu sprechen, negative Werte für die Art versicolor
. Nur in je einem Fall liegt dieses Modell für beide Arten liegt das Modell falsch.
6.5.1 Ein kurzer Ausflug in ROC-Analysen
Ohne zu sehr darauf eingehen zu wollen, an dieser Stelle ein kurzer Exkurs zu ROC-Analysen. In einer ROC-Kurve wird sozusagen jede Kombination von Sensitivität und Spezifität aufgetragen, die sich in Abhängigkeit von der Variation eines Parameters (in unserem Fall also bspw. der Ausgabewert der logistischen Regression) ergibt. Je näher die Fläche unterhalb der ROC-Kurve am Wert 1 ist, desto besser die diagnostische oder prädiktive Performance eines Modells.
In R können ROC-Kurven und dazugehörige Berechnungen sehr einfach mit dem Paket pROC
vorgenommen werden. Im Folgenden nehmen wir das Dataframe mit den ergänzten Vorhersagewerten und teilen der Funktion roc()
als ersten Parameter den wahren Wert (hier also Species
) und als zweiten Parameter den vorhersagenden Wert (hier also prediction_value
) mit.
library(pROC)
%>%
iris_binominal_predictions roc(Species, prediction_value)
##
## Call:
## roc.data.frame(data = ., response = Species, predictor = prediction_value)
##
## Data: prediction_value in 50 controls (Species versicolor) < 50 cases (Species virginica).
## Area under the curve: 0.9972
Wie wir erwartet hätte funktioniert unser Modell sehr gut und erreicht eine AUC von 0.99.
Noch eingängiger wäre nun auch eine entsprechende grafische Darstellung der Kurve. Erfreulicherweise enthält das Paket pROC
auch hierfür Funktionen, insbesondere die Funktion ggroc()
, die direkt eine in ggplot
weiter modifizierbare Grafik erzeugt.
%>%
iris_binominal_predictions roc(Species, prediction_value) %>%
ggroc(color = "steelblue", size = 1) +
geom_segment(aes(x = 1, xend = 0, y = 0, yend = 1),
color="grey", linetype="dashed") +
coord_fixed()
Wollen wir diese Grafik dann noch bspw. um weitere wichtige Informationen ergänzen, können wir auch ganz das einfach erreichen.
# zunächst speichern wir unsere ROC-Analyse in einer Variablen
<- iris_binominal_predictions %>%
roc_result roc(Species, prediction_value)
# AUC und optimaler threshold werden berechnet und in Variablen gespeichert
<- auc(roc_result)
roc_auc <- coords(roc_result, "best", transpose = FALSE)
roc_optimal_point
# ROC-Kurve mit Anmerkungen
%>%
roc_result ggroc(color = "steelblue", size = 1) +
geom_segment(aes(x = 1, xend = 0, y = 0, yend = 1),
color="grey", linetype="dashed") +
coord_fixed() +
annotate("text", x = 0.05, y = 0.1, label = paste("AUC =", roc_auc), hjust = 1) +
geom_text(data = roc_optimal_point,
aes(x = specificity, y = sensitivity, label = paste("Best operating point =", round(threshold, 2),
"\nSensitivity:", round(sensitivity, 2),
"\nSpecificity:", round(specificity, 2)
)),hjust = 0, vjust = 1,
nudge_x = 0.02, nudge_y = -0.02,
size = 3) +
geom_point(aes(x = roc_optimal_point$specificity, y = roc_optimal_point$sensitivity), color = "orange", shape = 4, size = 4)
6.6 Datenaufbereitung für Machine Learning
Für die Anwendung “komplizierterer” Machine Learning Algorithmen sollte man sich an einige wichtige Regel halten. Ohne, dass wir hier alle Details bedenken können, sind zwei wichtige Schritte sicherlich die Normalisierung der Daten und die Aufteilung in Trainings- und Testdaten.
Für beide Arbeitsschritte können wir Funktionen aus dem Paket caret
benutzen. Für die Normalisierung der Daten legen wir zunächst mit der Funktion preProcess()
ein Objekt an, das die relevanten Informationen enthält, dieses wiederum wenden wir dann mit der schon bekannten Funktion predict()
an.
library(caret)
<- preProcess(iris, method=c("range"))
preprocess_object
<- predict(preprocess_object, iris)
iris_normalized
%>%
iris_normalized pivot_longer(-Species, names_to = "variable") %>%
ggplot(aes(x = Species, y = value, color = Species)) +
geom_jitter() +
facet_wrap(vars(variable)) +
labs(y = "normalized value")
Man erkennt, dass nun alle Werte auf einen Bereich zwischen 0 und 1 abgebildet wurden, was bspw. die Vergleichbarkeit (auch visuell, vgl. Abschnitt 6.3) verbessern kann.
Im nächsten Schritt wollen wir die Daten aufteilen, sodass wir einen Datensatz erhalten, den wir zum Training unseres Algorithmus nutzen, und einen Datensatz, den wir zum Testen bzw. Validieren unseres Modells nutzen. Dabei sollte gewährleistet werden, dass in beiden Gruppen alle Werte annähernd ähnlich verteilt sind, die Zuordnung aber möglichst trotzdem zufällig stattfindet. Auch hier können wir uns einer Funktion des Pakets caret
bedienen, in dem Fall die Funktion createDataPartition()
. Dieser Funktion müssen wir als Parameter noch die Spalte mitgeben, die unsere vorherzusagende Variable enthält (in unserem Fall also iris$Species
), welcher Prozentsatz der Daten dem Training zugeordnet werden soll (z.B.: 80%) und ob wir ein list
-Objekt erhalten wollen (wenn wir nur eine Aufteilung machen wollen, sollte man hier FALSE
wählen). Da die Funktion uns nur die Position der Zeilen des Dataframes, die zur Trainingsgruppe gehören sollen, gibt, ist es praktischer diese in einer Variablen zu speichern, die wir dann nutzen, um aus dem ursprünglichen Dataframe die entsprechenden Zeilen zu selektieren bzw. zu entfernen.
<- createDataPartition(iris$Species, p = .8, list = FALSE)
training_index
<- iris[training_index,]
iris_training <- iris[-training_index,] iris_testing
Um das Ergebnis zu kontrollieren, können wir bspw. die bekannten Funktionen zur Erstellung deskriptiver Statistiken nutzen (bspw. die Funktionen aus dem Paket summarytools
, siehe auch Abschnitt 3.4.2).
# Da der Output recht lang ist, führen wir hier nur die Befehle auf.
library(summarytools)
dfSummary(iris_training)
dfSummary(iris_testing)
Sobald wir unsere Daten entsprechend aufgeteilt haben, können wir anfangen unsere Modelle zu trainieren.
6.7 Support Vector Machine
Ein gern und häufig verwendeter, etwas komplexerer Machine Learning Algorithmus zur Klassifikation auch mehrerer Gruppen ist die Support Vector Machine. Der Algorithmus versucht hierbei möglichst gute Trennungen zwischen den Klassen (auch mehreren) in einem sogenannten Vektorraum zu finden.
In R gibt es mehrere Pakete, die Support Vector Machines zur Verfügung stellen, ein Beispiel wäre das Paket e1071
. Das Vorgehen ist dem bei der logistischen Regression (siehe 6.5) sehr ähnlich. Der Funktion svm()
übergeben wir auch in diesem Falle eine Formel und den entsprechenden Datensatz. Hier ist es allerdings wichtig nur den Trainingsdatensatz zu verwenden.
library(e1071)
<- svm(Species ~ ., data = iris_training)
fit summary(fit)
##
## Call:
## svm(formula = Species ~ ., data = iris_training)
##
##
## Parameters:
## SVM-Type: C-classification
## SVM-Kernel: radial
## cost: 1
##
## Number of Support Vectors: 46
##
## ( 8 20 18 )
##
##
## Number of Classes: 3
##
## Levels:
## setosa versicolor virginica
Ein solches Modell erschließt sich entsprechend nicht ganz so intuitiv wie bspw. eine linieare Regression (6.4). Für die meisten unserer Anwendungen muss dies aber kein Nachteil sein, es sollte aber bei solchen Algorithmen in besonderem Maße darauf geachtet werden, dass letztlich nur die Ergebnisse in dem Testdatensatz bedeutend sind. Gute Ergebnisse in den Trainingsdaten können auch lediglich Ausdruck eines sogenannten Overfittings sein und sind nicht immer auf die Testdaten übertragbar.
In einem zweiten Schritt wenden wir, wie auch schon bei den Regressionen, daher nun das trainierte Modell auf neue Daten an - in diesem Fall den Testdatensatz, der nicht Teil des Trainingsprozesses war.
<- iris_testing %>%
iris_test add_column(prediction = predict(fit, iris_testing))
Zur Bewertung der diagnostischen Genauigkeit des Modells sind eine Reihe Maßzahlen üblich. Erfreulicherweise enthält das o.g. Paket caret
auch einige praktische Funktionen, um diese einfach berechnen zu lassen. Die Funktion confusionMatrix()
benötigt dazu als Parameter eine Kreuztabelle, die wir in R wiederum mithilfe der Funktion table()
erstellen können.
table(iris_test$Species, iris_test$prediction) %>%
confusionMatrix()
## Confusion Matrix and Statistics
##
##
## setosa versicolor virginica
## setosa 10 0 0
## versicolor 0 9 1
## virginica 0 0 10
##
## Overall Statistics
##
## Accuracy : 0.9667
## 95% CI : (0.8278, 0.9992)
## No Information Rate : 0.3667
## P-Value [Acc > NIR] : 4.476e-12
##
## Kappa : 0.95
##
## Mcnemar's Test P-Value : NA
##
## Statistics by Class:
##
## Class: setosa Class: versicolor Class: virginica
## Sensitivity 1.0000 1.0000 0.9091
## Specificity 1.0000 0.9524 1.0000
## Pos Pred Value 1.0000 0.9000 1.0000
## Neg Pred Value 1.0000 1.0000 0.9500
## Prevalence 0.3333 0.3000 0.3667
## Detection Rate 0.3333 0.3000 0.3333
## Detection Prevalence 0.3333 0.3333 0.3333
## Balanced Accuracy 1.0000 0.9762 0.9545
Die Funktion confusionMatrix()
gibt erfreulicherweise auch gleich die Kreuztabelle mit aus. Dabei entsprechen die Zeilen dem “wahren” Label oder Wert und die Spalten der Vorhersage des Algorithmus. Wir sehen also nun, dass sich der Algorithmus nur in zwei Fällen im Testdatensatz vertan hat (beide Male gab der Algorithmus Art versicolor
aus, obwohl es sich um Schwertlilien der Art virginica
handelte).
6.8 Weitere Informationen zu Machine Learning in R
Natürlich können wir an dieser Stelle nicht alle Machine Learning Algorithmen behandeln, die man in R nutzen kann. Eine tolle Zusammenstellung vieler Algorithmen, inkl. Tutorials in R, findet man auf der Webseite DataScienceDojo. Wie man schnell sieht sind die hier verwendeten nur ein kleiner Ausschnitt. Trotzdem, alleine mit diesen vier Paketen hat man viele der typischerweise in der Radiologie angewandten Techniken zur Verfügung: