8 Árboles
El método del árbol de decisión es una técnica de aprendizaje automático predictivo potente y popular que se utiliza tanto para la clasificación como para la regresión. También se conoce como árboles de clasificación y regresión (CART).
La implementación de R del algoritmo CART se llama RPART
(árboles de regresión y particionamiento recursivos) disponible en un paquete con el mismo nombre.
8.1 Carga de los paquetes R
install.packages("tidyverse")
install.packages("caret")
install.packages("rpart")
install.packages("rpart.plot")
8.2 Árboles de clasificación
Conjunto de datos: PimaIndiansDiabetes2 [en paquete mlbench], para predecir la probabilidad de ser diabético positivo basándose en múltiples variables clínicas.
Los datos contienen 768 individuos (mujeres) y 9 variables clínicas para predecir la probabilidad de que los individuos sean diabéticos positivos o negativos:
embarazadas: número de embarazos,
glucosa: concentración de glucosa plasmática,
presión: presión arterial diastólica (mm Hg),
tríceps: espesor del pliegue cutáneo del tríceps (mm),
insulina: insulina sérica de 2 horas (mu U / ml),
mass: índice de masa corporal (peso en kg / (altura en m) ) ^ 2),
pedigree: diabetes función del pedigree,
edad: edad (años),
diabetes: variable de clase positiva o negativa.
En R se utiliza la función
rpart
con el argumento:diabetes ~ .
donde diabetes es la variable de interés, que en este caso en categorica,~
significa que sigue los datos y.
significa que toma todas las variables de la base.*
library(tidyverse)
library(caret)
library("rpart.plot")
library(rpart)
library("mlbench")
data ("PimaIndiansDiabetes2" , package = "mlbench")
<-rpart(diabetes ~ ., data = PimaIndiansDiabetes2, method = "class" )
model1 plot(model1)
text(model1, digits = 3 )
R también nos permite ver el conjunto de reglas que define el árbol, en el caso de no poderlo detallar.
data ( "PimaIndiansDiabetes2" , package = "mlbench" )
<-rpart (diabetes ~ ., data = PimaIndiansDiabetes2, method = "class" )
model1 print(model1)
## n= 768
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 768 268 neg (0.65104167 0.34895833)
## 2) glucose< 127.5 485 94 neg (0.80618557 0.19381443)
## 4) age< 28.5 271 23 neg (0.91512915 0.08487085) *
## 5) age>=28.5 214 71 neg (0.66822430 0.33177570)
## 10) insulin< 142.5 164 48 neg (0.70731707 0.29268293)
## 20) glucose< 96.5 51 4 neg (0.92156863 0.07843137) *
## 21) glucose>=96.5 113 44 neg (0.61061947 0.38938053)
## 42) mass< 26.35 19 0 neg (1.00000000 0.00000000) *
## 43) mass>=26.35 94 44 neg (0.53191489 0.46808511)
## 86) pregnant< 5.5 49 15 neg (0.69387755 0.30612245)
## 172) age< 34.5 25 2 neg (0.92000000 0.08000000) *
## 173) age>=34.5 24 11 pos (0.45833333 0.54166667)
## 346) pressure>=77 10 2 neg (0.80000000 0.20000000) *
## 347) pressure< 77 14 3 pos (0.21428571 0.78571429) *
## 87) pregnant>=5.5 45 16 pos (0.35555556 0.64444444) *
## 11) insulin>=142.5 50 23 neg (0.54000000 0.46000000)
## 22) age>=56.5 12 1 neg (0.91666667 0.08333333) *
## 23) age< 56.5 38 16 pos (0.42105263 0.57894737)
## 46) age>=33.5 29 14 neg (0.51724138 0.48275862)
## 92) triceps>=27 22 8 neg (0.63636364 0.36363636) *
## 93) triceps< 27 7 1 pos (0.14285714 0.85714286) *
## 47) age< 33.5 9 1 pos (0.11111111 0.88888889) *
## 3) glucose>=127.5 283 109 pos (0.38515901 0.61484099)
## 6) mass< 29.95 75 24 neg (0.68000000 0.32000000) *
## 7) mass>=29.95 208 58 pos (0.27884615 0.72115385)
## 14) glucose< 157.5 116 46 pos (0.39655172 0.60344828)
## 28) age< 30.5 50 23 neg (0.54000000 0.46000000)
## 56) pressure>=73 29 10 neg (0.65517241 0.34482759)
## 112) mass< 41.8 20 4 neg (0.80000000 0.20000000) *
## 113) mass>=41.8 9 3 pos (0.33333333 0.66666667) *
## 57) pressure< 73 21 8 pos (0.38095238 0.61904762) *
## 29) age>=30.5 66 19 pos (0.28787879 0.71212121) *
## 15) glucose>=157.5 92 12 pos (0.13043478 0.86956522) *
8.2.1 Predicciones
Los diferentes conjuntos de reglas establecidos en el árbol se utilizan para predecir el resultado de una nueva prueba de datos. El siguiente código R predice con los siguientes datos:
<-data.frame(glucose = 98 ,insulin = 148, age=30, pregnant=6, triceps=29, mass=30, pedigree=0.8, pressure=70)
NuevosDatos %>% predict(NuevosDatos) model1
## neg pos
## 1 0.1111111 0.8888889
Entonces el resultado será:
%>% predict(NuevosDatos)
model1
neg pos1 0.1111111 0.8888889
es decir, el paciente tiene diabetes con una probabilidad de 0.888.
Si por ejemplo, solo quiere utilizar las variables age
y glucose
se pone en rpart
diabetes ~ age+glucose
y solo da un árbol con solo estas variables.
data ( "PimaIndiansDiabetes2" , package = "mlbench" )
<-rpart(diabetes ~ age+glucose, data = PimaIndiansDiabetes2, method = "class" )
model2 plot(model2)
text(model2, digits = 3 )
Si el árbol sale feo y no se ve tan chévere, recomiendo el parquete rpart.plot
:
library(rpart.plot)
rpart.plot(model2, box.palette="RdBu", shadow.col="gray", nn=TRUE)
8.2.2 Datos del iris de Edgar Anderson
Este famoso conjunto de datos de iris (de Fisher o Anderson) da las medidas en centímetros de las variables longitud y ancho del sépalo y largo y ancho del pétalo, respectivamente, para 50 flores de cada una de las 3 especies de iris. Las especies son Iris setosa, versicolor y virginica.
iris
es un marco de datos con 150 casos (filas) y 5 variables (columnas) denominadas Sepal.Length, Sepal.Width, Petal.Length, Petal.Width y Species.
El siguiente ejemplo representa un modelo de árbol que predice la especie de flor de iris según la longitud (en cm) y el ancho del sépalo y el pétalo
library(tidyverse)
library(caret)
library(rpart)
data ( "iris" )
<-rpart(Species ~ ., data = iris, method = "class")
model par( xpd = NA ) # ayuda a ajustar el gráfico.
rpart.plot(model)
text (model, digits = 3 )
8.3 Árboles de regresión
El código R es idéntico al que hemos visto en apartados anteriores. La única diferencia es que al final del código es method = "anova"
. No es casualidad, lo que hace es comparar todas las variables con todas.
Ejemplo de conjunto de datos: Usaremos el conjunto de datos de Boston [en el paquete MASS], para predecir el valor mediano de la casa (mdev), en los suburbios de Boston, usando diferentes variables predictoras.
Cargue los datos de datos (“Boston,” paquete = “MASS”). Este marco de datos contiene las siguientes columnas:
crim: tasa de criminalidad per cápita por ciudad.
zn: proporción de terreno residencial dividido en zonas para lotes de más de 25,000 pies cuadrados.
indus: proporción de acres comerciales no minoristas por ciudad.
chas: Variable ficticia de Charles River (= 1 si el tramo limita con el río; 0 en caso contrario).
nox: concentración de óxidos de nitrógeno (partes por 10 millones).
rm: número medio de habitaciones por vivienda.
age: Proporción de unidades ocupadas por sus propietarios construidas antes de 1940.
dis: media ponderada de las distancias a cinco centros de empleo de Boston.
rad: índice de accesibilidad a carreteras radiales.
tax: Tasa de impuesto a la propiedad de valor total por $ 10,000.
ptratio: Proporción alumno-profesor por ciudad.
black: 1000 (Bk - 0.63) ^ 2 donde Bk es la proporción de negros por ciudad. (ESTA VARIABLE ES RACISTA)
lstat: estatus más bajo de la población (porcentaje).
medv: valor medio de las viviendas ocupadas por sus propietarios en $ 1000.
data( "Boston" , package = "MASS" )
<-rpart(medv ~ ., data = Boston, method = "anova")
model rpart.plot(model)
text (model, digits = 3 )
Elimine algunas variables, como la racista. Solo considere crim
, nox
y dis
y ejecute el árbol:
data( "Boston" , package = "MASS" )
<-rpart(medv ~ crim+nox+dis, data = Boston, method = "anova")
model rpart.plot(model)
text (model, digits = 3 )