Albero decisionale in R - Albero di classificazione & Codice in R con Esempio

Sommario:

Anonim

Cosa sono gli alberi decisionali?

Gli alberi decisionali sono un algoritmo di Machine Learning versatile in grado di eseguire sia attività di classificazione che di regressione. Sono algoritmi molto potenti, in grado di adattarsi a set di dati complessi. Inoltre, gli alberi decisionali sono componenti fondamentali delle foreste casuali, che sono tra i più potenti algoritmi di Machine Learning disponibili oggi.

Formazione e visualizzazione di un albero decisionale

Per costruire il tuo primo albero decisionale nell'esempio R, procederemo come segue in questo tutorial sull'albero decisionale:

  • Passaggio 1: importa i dati
  • Passaggio 2: pulire il set di dati
  • Passaggio 3: creare un set di addestramento / test
  • Passaggio 4: costruire il modello
  • Passaggio 5: fare una previsione
  • Passaggio 6: misurare le prestazioni
  • Passaggio 7: ottimizza gli iperparametri

Passaggio 1) Importa i dati

Se sei curioso del destino del Titanic, puoi guardare questo video su Youtube. Lo scopo di questo set di dati è prevedere quali persone hanno maggiori probabilità di sopravvivere dopo la collisione con l'iceberg. Il set di dati contiene 13 variabili e 1309 osservazioni. Il set di dati è ordinato dalla variabile X.

set.seed(678)path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'titanic <-read.csv(path)head(titanic)

Produzione:

## X pclass survived name sex## 1 1 1 1 Allen, Miss. Elisabeth Walton female## 2 2 1 1 Allison, Master. Hudson Trevor male## 3 3 1 0 Allison, Miss. Helen Loraine female## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female## 6 6 1 1 Anderson, Mr. Harry male## age sibsp parch ticket fare cabin embarked## 1 29.0000 0 0 24160 211.3375 B5 S## 2 0.9167 1 2 113781 151.5500 C22 C26 S## 3 2.0000 1 2 113781 151.5500 C22 C26 S## 4 30.0000 1 2 113781 151.5500 C22 C26 S## 5 25.0000 1 2 113781 151.5500 C22 C26 S## 6 48.0000 0 0 19952 26.5500 E12 S## home.dest## 1 St Louis, MO## 2 Montreal, PQ / Chesterville, ON## 3 Montreal, PQ / Chesterville, ON## 4 Montreal, PQ / Chesterville, ON## 5 Montreal, PQ / Chesterville, ON## 6 New York, NY
tail(titanic)

Produzione:

## X pclass survived name sex age sibsp## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0## parch ticket fare cabin embarked home.dest## 1304 0 2627 14.4583 C## 1305 0 2665 14.4542 C## 1306 0 2665 14.4542 C## 1307 0 2656 7.2250 C## 1308 0 2670 7.2250 C## 1309 0 315082 7.8750 S

Dall'output di testa e coda, puoi notare che i dati non vengono mescolati. Questo è un grosso problema! Quando suddividerai i tuoi dati tra un set di treni e un set di test, selezionerai solo il passeggero della classe 1 e 2 (nessun passeggero della classe 3 è nell'80% superiore delle osservazioni), il che significa che l'algoritmo non vedrà mai il caratteristiche del passeggero di classe 3. Questo errore porterà a una scarsa previsione.

Per risolvere questo problema, puoi utilizzare la funzione sample ().

shuffle_index <- sample(1:nrow(titanic))head(shuffle_index)

Albero decisionale Codice R Spiegazione

  • sample (1: nrow (titanic)): genera un elenco casuale di indici da 1 a 1309 (ovvero il numero massimo di righe).

Produzione:

## [1] 288 874 1078 633 887 992 

Utilizzerai questo indice per mescolare il set di dati titanico.

titanic <- titanic[shuffle_index, ]head(titanic)

Produzione:

## X pclass survived## 288 288 1 0## 874 874 3 0## 1078 1078 3 1## 633 633 3 0## 887 887 3 1## 992 992 3 1## name sex age## 288 Sutton, Mr. Frederick male 61## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42## 1078 O'Driscoll, Miss. Bridget female NA## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39## 887 Jermyn, Miss. Annie female NA## 992 Mamee, Mr. Hanna male NA## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ## 874 0 0 348121 7.6500 F G63 S## 1078 0 0 14311 7.7500 Q## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN## 887 0 0 14313 7.7500 Q## 992 0 0 2677 7.2292 C

Passaggio 2) Pulisci il set di dati

La struttura dei dati mostra che alcune variabili hanno NA. La pulizia dei dati deve essere eseguita come segue

  • Rilascia le variabili home.dest, cabin, name, X e ticket
  • Crea variabili fattore per pclass e sopravvive
  • Lascia cadere il NA
library(dplyr)# Drop variablesclean_titanic <- titanic % > %select(-c(home.dest, cabin, name, X, ticket)) % > %#Convert to factor levelmutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %na.omit()glimpse(clean_titanic)

Spiegazione del codice

  • select (-c (home.dest, cabin, name, X, ticket)): elimina le variabili non necessarie
  • pclass = factor (pclass, levels = c (1,2,3), labels = c ('Upper', 'Middle', 'Lower')): Aggiungi etichetta alla variabile pclass. 1 diventa Upper, 2 diventa MIddle e 3 diventa lower
  • fattore (sopravvissuto, livelli = c (0,1), etichette = c ('No', 'Sì')): Aggiungi etichetta alla variabile sopravvissuta. 1 diventa No e 2 diventa Sì
  • na.omit (): rimuove le osservazioni NA

Produzione:

## Observations: 1,045## Variables: 8## $ pclass  Upper, Lower, Lower, Upper, Middle, Upper, Middle, U… ## $ survived  No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y… ## $ sex  male, male, female, female, male, male, female, male… ## $ age  61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0,… ## $ sibsp  0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,… ## $ parch  0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,… ## $ fare  32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542,… ## $ embarked  S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C… 

Passaggio 3) Crea set di addestramento / test

Prima di addestrare il tuo modello, devi eseguire due passaggi:

  • Creare un set di treni e test: addestrare il modello sul set di treni e testare la previsione sul set di test (ovvero dati non visualizzati)
  • Installa rpart.plot dalla console

La pratica comune è quella di dividere i dati 80/20, l'80% dei dati serve per addestrare il modello e il 20% per fare previsioni. È necessario creare due frame di dati separati. Non si desidera toccare il set di prova finché non si finisce di costruire il modello. È possibile creare un nome di funzione create_train_test () che accetta tre argomenti.

create_train_test(df, size = 0.8, train = TRUE)arguments:-df: Dataset used to train the model.-size: Size of the split. By default, 0.8. Numerical value-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample < - 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}

Spiegazione del codice

  • function (data, size = 0.8, train = TRUE): Aggiungi gli argomenti nella funzione
  • n_row = nrow (data): conta il numero di righe nel set di dati
  • total_row = size * n_row: restituisce l'ennesima riga per costruire il treno
  • train_sample <- 1: total_row: Seleziona la prima riga all'ennesima riga
  • if (train == TRUE) {} else {}: se la condizione è impostata su true, restituisce il set di treni, altrimenti il ​​set di test.

Puoi testare la tua funzione e controllare la dimensione.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)dim(data_train)

Produzione:

## [1] 836 8
dim(data_test)

Produzione:

## [1] 209 8 

Il set di dati del treno ha 1046 righe mentre il set di dati di test ha 262 righe.

Si utilizza la funzione prop.table () combinata con table () per verificare se il processo di randomizzazione è corretto.

prop.table(table(data_train$survived))

Produzione:

#### No Yes## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Produzione:

#### No Yes## 0.5789474 0.4210526

In entrambi i set di dati, la quantità di sopravvissuti è la stessa, circa il 40%.

Installa rpart.plot

rpart.plot non è disponibile nelle librerie conda. Puoi installarlo dalla console:

install.packages("rpart.plot") 

Passaggio 4) Costruisci il modello

Sei pronto per costruire il modello. La sintassi per la funzione albero decisionale Rpart è:

rpart(formula, data=, method='')arguments:- formula: The function to predict- data: Specifies the data frame- method:- "class" for a classification tree- "anova" for a regression tree

Utilizzi il metodo della classe perché prevedi una classe.

library(rpart)library(rpart.plot)fit <- rpart(survived~., data = data_train, method = 'class')rpart.plot(fit, extra = 106

Spiegazione del codice

  • rpart (): funzione per adattare il modello. Gli argomenti sono:
    • sopravvissuto ~ .: Formula degli alberi decisionali
    • data = data_train: Dataset
    • method = 'class': adatta un modello binario
  • rpart.plot (fit, extra = 106): traccia l'albero. Le funzionalità extra sono impostate su 101 per visualizzare la probabilità della 2a classe (utile per le risposte binarie). Puoi fare riferimento alla vignetta per ulteriori informazioni sulle altre scelte.

Produzione:

Inizi dal nodo radice (profondità 0 su 3, la parte superiore del grafico):

  1. In alto, è la probabilità complessiva di sopravvivenza. Mostra la percentuale di passeggeri sopravvissuti all'incidente. Il 41% dei passeggeri è sopravvissuto.
  2. Questo nodo chiede se il sesso del passeggero è maschio. Se sì, allora vai al nodo figlio sinistro della radice (profondità 2). Il 63% sono maschi con una probabilità di sopravvivenza del 21%.
  3. Nel secondo nodo chiedi se il passeggero maschio ha più di 3,5 anni. Se sì, la possibilità di sopravvivenza è del 19%.
  4. Continui così per capire quali caratteristiche influenzano la probabilità di sopravvivenza.

Si noti che una delle molte qualità degli alberi decisionali è che richiedono una preparazione dei dati molto ridotta. In particolare, non richiedono il ridimensionamento o il centraggio delle caratteristiche.

Per impostazione predefinita, la funzione rpart () utilizza la misura dell'impurità di Gini per dividere la nota. Più alto è il coefficiente di Gini, più diverse sono le istanze all'interno del nodo.

Passaggio 5) Fai una previsione

Puoi prevedere il tuo set di dati di test. Per fare una previsione, puoi utilizzare la funzione Forecast (). La sintassi di base della previsione per l'albero decisionale R è:

predict(fitted_model, df, type = 'class')arguments:- fitted_model: This is the object stored after model estimation.- df: Data frame used to make the prediction- type: Type of prediction- 'class': for classification- 'prob': to compute the probability of each class- 'vector': Predict the mean response at the node level

Si desidera prevedere quali passeggeri hanno maggiori probabilità di sopravvivere dopo la collisione dal set di prova. Significa che saprai tra quei 209 passeggeri, quale sopravviverà o no.

predict_unseen <-predict(fit, data_test, type = 'class')

Spiegazione del codice

  • predire (fit, data_test, type = 'class'): prevedere la classe (0/1) del set

Testare il passeggero che non ce l'ha fatta e quelli che ce l'hanno fatta.

table_mat <- table(data_test$survived, predict_unseen)table_mat

Spiegazione del codice

  • table (data_test $ survived, Forecast_unseen): Crea una tabella per contare quanti passeggeri sono classificati come sopravvissuti e deceduti confrontandoli con la corretta classificazione dell'albero decisionale in R

Produzione:

## predict_unseen## No Yes## No 106 15## Yes 30 58

Il modello prevedeva correttamente 106 passeggeri morti ma classificava 15 sopravvissuti come morti. Per analogia, il modello ha erroneamente classificato 30 passeggeri come sopravvissuti mentre si sono rivelati morti.

Passaggio 6) Misura le prestazioni

È possibile calcolare una misura di accuratezza per l'attività di classificazione con la matrice di confusione :

La matrice di confusione è una scelta migliore per valutare le prestazioni di classificazione. L'idea generale è contare il numero di volte in cui le istanze Vere sono classificate sono False.

Ogni riga in una matrice di confusione rappresenta un obiettivo effettivo, mentre ogni colonna rappresenta un obiettivo previsto. La prima riga di questa matrice considera i passeggeri morti (la classe False): 106 sono stati correttamente classificati come morti ( Vero negativo ), mentre il rimanente è stato erroneamente classificato come sopravvissuto ( Falso positivo ). La seconda riga considera i sopravvissuti, la classe positiva era 58 ( Vero positivo ), mentre la Vero negativa era 30.

È possibile calcolare il test di accuratezza dalla matrice di confusione:

È la proporzione di vero positivo e vero negativo sulla somma della matrice. Con R, puoi codificare come segue:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Spiegazione del codice

  • sum (diag (table_mat)): somma della diagonale
  • sum (table_mat): somma della matrice.

È possibile stampare la precisione del set di prova:

print(paste('Accuracy for test', accuracy_Test))

Produzione:

## [1] "Accuracy for test 0.784688995215311" 

Hai un punteggio del 78 percento per il set di test. È possibile replicare lo stesso esercizio con il set di dati di allenamento.

Step 7) Regola gli iperparametri

L'albero decisionale in R ha vari parametri che controllano gli aspetti dell'adattamento. Nella libreria dell'albero decisionale rpart, è possibile controllare i parametri utilizzando la funzione rpart.control (). Nel codice seguente vengono introdotti i parametri che verranno sintonizzati. Puoi fare riferimento alla vignetta per altri parametri.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)Arguments:-minsplit: Set the minimum number of observations in the node before the algorithm perform a split-minbucket: Set the minimum number of observations in the final note i.e. the leaf-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Procederemo come segue:

  • Costruisci la funzione per restituire la precisione
  • Regola la profondità massima
  • Regola il numero minimo di campioni che un nodo deve avere prima di poter essere suddiviso
  • Regola il numero minimo di campioni che un nodo foglia deve avere

È possibile scrivere una funzione per visualizzare la precisione. Devi semplicemente racchiudere il codice che hai usato prima:

  1. predire: predicire_unseen <- predire (adattamento, test_dati, tipo = 'classe')
  2. Produci tabella: table_mat <- table (data_test $ survived, Forecast_unseen)
  3. Precisione di calcolo: accuratezza_test <- sum (diag (table_mat)) / sum (table_mat)
accuracy_tune <- function(fit) {predict_unseen <- predict(fit, data_test, type = 'class')table_mat <- table(data_test$survived, predict_unseen)accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test}

Puoi provare a regolare i parametri e vedere se riesci a migliorare il modello rispetto al valore predefinito. Come promemoria, è necessario ottenere una precisione superiore a 0,78

control <- rpart.control(minsplit = 4,minbucket = round(5 / 3),maxdepth = 3,cp = 0)tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)accuracy_tune(tune_fit)

Produzione:

## [1] 0.7990431 

Con il seguente parametro:

minsplit = 4minbucket= round(5/3)maxdepth = 3cp=0 

Ottieni prestazioni più elevate rispetto al modello precedente. Congratulazioni!

Sommario

Possiamo riassumere le funzioni per addestrare un algoritmo ad albero decisionale in R

Biblioteca

Obbiettivo

funzione

classe

parametri

dettagli

rpart

Albero di classificazione dei treni in R

rpart ()

classe

formula, df, metodo

rpart

Addestra albero di regressione

rpart ()

anova

formula, df, metodo

rpart

Traccia gli alberi

rpart.plot ()

modello montato

base

prevedere

prevedere ()

classe

modello montato, tipo

base

prevedere

prevedere ()

prob

modello montato, tipo

base

prevedere

prevedere ()

vettore

modello montato, tipo

rpart

Parametri di controllo

rpart.control ()

minsplit

Imposta il numero minimo di osservazioni nel nodo prima che l'algoritmo esegua una divisione

minbucket

Imposta il numero minimo di osservazioni nella nota finale, cioè la foglia

profondità massima

Imposta la profondità massima di qualsiasi nodo dell'albero finale. Il nodo radice viene trattato con una profondità 0

rpart

Modello di treno con parametro di controllo

rpart ()

formula, df, metodo, controllo

Nota: addestrare il modello su dati di addestramento e testare le prestazioni su un set di dati non visualizzato, ad esempio un set di test.