Zum Hauptinhalt springen

Regularisierungspfade mit HTML-Widgets übersichtlich visualisieren - Eine übersichtliche Alternative zur Spaghettiknäuelvisualisierung

 

Das glmnet-Package, das die Elastic-Net-Regression von Hasti und Tibshirani implementiert, ist eines der Arbeitspferde der Data Science, zumindest soweit sie mit R betrieben wird. Das Paket hat einige eingebaute Möglichkeiten zur Visualisierung. Insbesondere kann man sich die Regularisierungspfade grafisch zeigen lassen. Diese Visualisierung hat jedoch ihre Tücken. Wenn das Modell viele Variablen enthält, sieht das Ergebnis aus wie ein Knäuel bunte Spaghetti. Es hilft dann beim Verständnis des Modellverhaltens nicht wirklich weiter. Leider benutzt man Elastic Net gerade dann besonders gern, wenn man viele bis furchtbar viele Variablen hat.

Eine bessere Visualisierung muss also her. Alle Informationen gleichzeitig darzustellen, wie man das in einem statischen Plot muss, stößt bei vielen Variablen an natürliche Grenzen. Irgendwann ist die Zeichenfläche halt voll. Ein Ausweg ist die Visualisierung mit HTML und JavaScript. Sie bietet die Möglichkeit, bei Berührung mit der Maus einen einzelnen Regularisierungspfad hervorzuheben, so dass sein Verlauf im Spaghettiknäuel deutlich wird. Gleichzeitig kann man zusätzliche Informationen anbieten, in unserem Fall den Variablennamen. Auf diese Weise wird das Spaghettiknäuel entzifferbar.

HTML-Widgets erzeugen mit vertrauter Syntax

Nun ist es eine der Krankheiten von R, dass es zwar für alles ein Paket gibt, dass aber meist jedes Paket seine sehr eigenen Befehlsvarianten verwendet. Die Syntax ist weit entfernt von einer Einheitlichkeit, wie man sie (trotz Package-Vielfalt) bei Python findet. Es ist also begrüßenswert, wenn sich irgendwo zarte Ansätze zur Vereinheitlichung zeigen. Für die Visualisierung ist das beim ggplot2-Universum der Fall. ggplot2 kennt fast jeder, der mit R arbeitet. Es ist (neben den eingebauten Grafikmöglichkeiten von R) das verbreitetste R-Paket zur Datenvisualisierung. Warum ich hier gleich von einem „Universum“ spreche, ist vielleicht noch nicht jedem klar. ggplot2 bietet seit einiger Zeit eine Schnittstelle an, mit der sich das Paket um zusätzliche Funktionalitäten erweitern lässt. Unter www.ggplot2-exts.org gibt es eine Übersicht über die bereits recht umfangreiche Sammlung.

Eine dieser Erweiterungen nennt sich „ggiraph“ und bietet HTML/JavaScript-Varianten bekannter Standardplots an. Gegenüber anderen Erweiterungen, mit denen sich so etwas auch realisieren lässt, hat ggiraph den Vorteil, dass man keine komplett neue Syntax lernen und bestehenden Code für ggplot2-Plots nur leicht anpassen muss, wenn man HTML-Widgets produzieren möchte. Für mich ist das ein schlagendes Argument für ggiraph. Ich muss allerdings zugeben, dass ich Alternativen wie metricsgraphics, plotly, highcharter oder rbokeh nicht ausprobiert habe. Wenn jemand damit Erfahrungen gesammelt hat, bin ich neugierig, davon zu hören.

glmnet-Visualisierung durch HTML-Widget ersetzen

Um unser konkretes Visualisierungsproblem zu lösen, müssen wir eine Alternative zur Funktion im Elastic-Net-Package bauen, die an Stelle eines Standardplots ein HTML-Widget erzeugt. Damit am Ende etwas herauskommt, was mit dem Original vergleichbar ist, hat es Sinn, sich an die entsprechende plot-Routine des Elastic-Net-Packages zu halten und diese in die Syntax von ggiraph zu übersetzen. Leider verwendet Elastic Net nicht ggplot2 zur Visualisierung, sondern die Basisgrafik von R, sonst wäre diese Aufgabe noch leichter. Aber auch so ist sie leicht erledigt, und es bleibt noch Zeit, ein paar kleine Zusatzfeatures einzubauen, die den Überblick weiter erleichtern. Zum einen haben wir eine Möglichkeit eingebaut, den dargestellten Wertebereich auf der y-Achse mit Hilfe eines Parameters ylimits einzuschränken. Zum anderen gibt es einen weiteren Parameter max_abs_coeff_range, der es ermöglicht, einzuschränken, welche Variablen dargestellt werden. Es werden nur solche Variablen aufgenommen, für die der maximale Betrag des Koeffizienten zwischen max_abs_coeff_range[1] und max_abs_coeff_range[2] liegt. Das kann man benutzen, um allzu unübersichtliche Plots in mehrere Teile zu zerlegen: einen für Variablen, die an irgendeiner Stelle ihres Regularisierungpfades hohe Koeffizienten aufweisen, und einen für solche, die eher kleine Koeffizienten haben. Damit vermeidet man, dass die Regularisierungspfade der letzteren Variablen sich in der Visualisierung so dicht an die x-Achse schmiegen, dass sie nicht zu unterscheiden sind. Außerdem haben wir uns noch Parameter title und ylab eingebaut, die es erlauben, die Titel der Grafik und der y-Achse zu setzen.

Wie ist der Code strukturiert?

Der wiederverwendbare Teil des Codes ist die Funktion plot_regupath. Sie lehnt sich eng an die Funktion plot.glmnet aus dem glmnet-Package an sowie an die dazugehörige interne glmnet-Funktion plotCoef. Neben den schon oben erwähnten Zusatzparametern verlangt die Funktion dieselben Parameter x (für das gefittete Modell) und xvar (für die Art der Darstellung) wie die plot-Funktion aus dem glmnet-Package. Eine ausführlichere Erläuterung dieser Parameter ist in der glmnet-Dokumentation zu finden.

library(tidyr)
library(ggplot2)
library(ggiraph)
library(htmlwidgets)
library(data.table)
library(glmnet)
plot_regupath <- function (x, xvar = c("norm", "lambda", "dev"), ylab = "Coefficient", ylimits=NULL, title=NULL, max_abs_coeff_range=NULL) 
{
    beta <- x$beta
    lambda <- x$lambda
    df <- x$df
    dev <- x$dev.ratio
    
    which = nonzeroCoef(beta)
    nwhich = length(which)
    switch(nwhich + 1, `0` = {
        warning("No plot produced since all coefficients zero")
        return()
    }, `1` = warning("1 or less nonzero coefficients; glmnet plot is not meaningful"))
    beta = as.matrix(beta[which, , drop = FALSE])
    xvar = match.arg(xvar)
    switch(xvar, norm = {
        index = apply(abs(beta), 2, sum)
        iname = "L1 Norm"
        approx.f = 1
    }, lambda = {
        index = log(lambda)
        iname = "Log Lambda"
        approx.f = 0
    }, dev = {
        index = dev
        iname = "Fraction Deviance Explained"
        approx.f = 1
    })
    
    data_for_plot <- tidyr::gather(data.frame(x_values=index, t(beta)), 
                                   key="variable", value="coefficient", 2:(nrow(beta)+1))
    
    data_for_plot <- as.data.table(data_for_plot)
    data_for_plot[, max_abs_coeff:=max(abs(coefficient)), by=variable]
    
    if (!is.null(max_abs_coeff_range)) {
        data_for_plot <- data_for_plot[max_abs_coeff > max_abs_coeff_range[1] & max_abs_coeff < max_abs_coeff_range[2]]
    }
    
    plot <- ggplot(data_for_plot, aes(x=x_values, y=coefficient, colour=variable)) + 
        # use interactive HTML-Version of a lineplot, with transparent lines, and a tooltip displaying the variable:
        geom_line_interactive(alpha=0.2, aes(data_id=variable, tooltip=variable)) +  
        guides(colour=FALSE) +              # remove the legend
        xlab(label=iname) +                 # set x-axis label
        ylab(label=ylab) +                  # set y-axis label
        theme_bw()                          # change overall appearance to a very reduced one
    
    if (!is.null(ylimits)) {
        plot <- plot + ylim(ylimits)
    }
    
    if (!is.null(title)) {
        plot <- plot + ggtitle(label=title)
    }
    
    
    result <- ggiraph::ggiraph(code={print(plot)}, tooltip_opacity=0.5, tooltip_offy = 10, height_svg=6, width_svg=12, zoom_max=100, hover_css="stroke-width:3;stroke-opacity:1")
    
    return(result)
}

c_working_directory <- ""
c_input_file_name <- "FullData.csv"
c_output_file_name <- "Regularisation_path.html"

c_predictors_to_omit <- c("Rating", "Name", "Nationality", "Club", "Contract_Expiry", "Birth_Date", "Height", "Weight", 
                          "National_Kit", "Club_Position", "Club_Kit", "Club_Joining", "Preffered_Position")

set.seed(4711)
setwd(c_working_directory)

dataset <- fread(c_input_file_name, encoding="UTF-8", sep=",", dec=".")

# exclude irrelevant predictors (and some relevant ones, to keep the dataset small)
relevant_predictors <- names(dataset)
relevant_predictors <- relevant_predictors[!relevant_predictors %in% c_predictors_to_omit] 

dataset[, above_average_ind:=as.integer(Rating>mean(Rating))] # add a 0/1-indicator whether a player is above average
dataset[, height_numeric:=as.integer(gsub("cm", "", Height))]
dataset[, weight_numeric:=as.integer(gsub("kg", "", Weight))]


prediction_formula <- as.formula(paste0("above_average_ind ~ ", paste(relevant_predictors, collapse=" + ")))

train_set_size <- floor(0.8 * nrow(dataset))
data_train <- dataset[sample(.N, train_set_size),]

model_matrix <- model.matrix(prediction_formula, data=data_train)
target <- as.factor(data_train[, above_average_ind])

glmnet_fit <- glmnet(model_matrix, target, alpha=1, family="binomial", standardize=TRUE)

html_plot <- plot_regupath(glmnet_fit, xvar="norm")
saveWidget(html_plot, c_output_file_name) # save plot to file, result is best viewed in Browser

Was sind das für Fußballdaten, die in dem Beispiel verwendet werden?

Der Rest des Codes (außer der Funktion plot_regupath) demonstriert ihre Funktion an einem Beispiel. Die Datengrundlage stammt aus dem Videospiel FIFA 2017 von Electronic Arts. Die Daten kann man bei Kaggle herunterladen:

https://www.kaggle.com/artimous/complete-fifa-2017-player-dataset-global/version/2

Sie enthalten für jeden Spieler diverse Variablen, und eine Auswahl davon benutzen wir in unserem Codebeispiel, um eine binäre Zielvariable mittels logistischer Regression vorherzusagen. Die Zielvariable haben wir aus den Daten konstruiert, indem wir einem Spieler eine 1 zugewiesen haben, wenn die Variable „Rating“ über dem Durchschnitt liegt, und 0 sonst. Wir versuchen also die besonders guten Spieler zu identifizieren. Wir reduzieren Datenaufbereitung und Feature Engineering auf ein kaum mehr wahrnehmbares Minimum, um die Funktionsweise der Visualisierung mit möglichst wenig Code demonstrieren zu können. Am Schluss wird mit saveWidget die Grafik als HTML-Datei gespeichert. Diese kann man dann im Browser anschauen (siehe auch die Abbildung). Darüberfahren mit der Maus nicht vergessen!

Beispiel HTML-Widget