To make a profit, insurance companies must collect higher premiums than the amounts they pay out to the insured. For this reason, insurers invest a great deal of time, effort, and money in building models that accurately predict healthcare costs.
To carry out this mission, we will first analyse the factors that influence medical charges, and then try to build an adequate model and optimize its performance.
insurance_data <- read.csv("insurance.csv")
library(ggthemes)
library(tidyverse)
library(ggridges)
library(magrittr)
library(gridExtra)
library(patchwork)
library(inspectdf)
library(corrplot)
library(tidymodels)
library(car)
library(caret)
theme_set(theme_bw())
glimpse(insurance_data)
## Rows: 1,338
## Columns: 7
## $ age <int> 19, 18, 28, 33, 32, 31, 46, 37, 37, 60, 25, 62, 23, 56, 27...
## $ sex <chr> "female", "male", "male", "male", "male", "female", "femal...
## $ bmi <dbl> 27.900, 33.770, 33.000, 22.705, 28.880, 25.740, 33.440, 27...
## $ children <int> 0, 1, 3, 0, 0, 0, 1, 3, 2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0...
## $ smoker <chr> "yes", "no", "no", "no", "no", "no", "no", "no", "no", "no...
## $ region <chr> "southwest", "southeast", "southeast", "northwest", "north...
## $ charges <dbl> 16884.924, 1725.552, 4449.462, 21984.471, 3866.855, 3756.6...
Many of the factors that determine how much a policyholder pays for medical insurance are beyond their control. It is nevertheless useful to understand them. Below are the factors collected by the insurer whose influence on the cost of medical insurance premiums we will study:
The dataset contains 1,338 observations on 7 variables.
We convert the variables sex, children, region, and smoker to factors, the type for categorical variables:
insurance_data$sex %<>% as.factor()
insurance_data$children %<>% as.factor()
insurance_data$region %<>% as.factor()
insurance_data$smoker %<>% as.factor()
summary(insurance_data)
## age sex bmi children smoker
## Min. :18.00 female:662 Min. :15.96 0:574 no :1064
## 1st Qu.:27.00 male :676 1st Qu.:26.30 1:324 yes: 274
## Median :39.00 Median :30.40 2:240
## Mean :39.21 Mean :30.66 3:157
## 3rd Qu.:51.00 3rd Qu.:34.69 4: 25
## Max. :64.00 Max. :53.13 5: 18
## region charges
## northeast:324 Min. : 1122
## northwest:325 1st Qu.: 4740
## southeast:364 Median : 9382
## southwest:325 Mean :13270
## 3rd Qu.:16640
## Max. :63770
insurance_data %<>% mutate(bmi_cat = cut(bmi,
breaks = c(0, 18.5, 25, 30, 60),
labels = c("Under Weight", "Normal Weight", "Overweight", "Obese")
))
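Note that `cut()` builds right-closed intervals by default, so a value sitting exactly on a break falls in the lower category; a quick check with the same breaks and labels:

```r
# Same breaks and labels as above; cut() is right-closed by default:
# (0, 18.5], (18.5, 25], (25, 30], (30, 60]
breaks <- c(0, 18.5, 25, 30, 60)
labels <- c("Under Weight", "Normal Weight", "Overweight", "Obese")

as.character(cut(c(17, 25, 30, 30.1), breaks = breaks, labels = labels))
# "Under Weight" "Normal Weight" "Overweight" "Obese"
```

So a bmi of exactly 25 counts as "Normal Weight" and exactly 30 as "Overweight".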
categ_cols <- insurance_data %>% select(where(is.factor))
for (col in names(categ_cols)) {
insurance_data %>%
group_by(.data[[col]]) %>%
summarise(count = n()) %>%
mutate(frequency = paste0(round(100 * count / sum(count), 0), "%")) %>%
knitr::kable("html", align = "lcc") %>%
kableExtra::kable_styling(full_width = F, position = "left") %>%
print()
}
sex | count | frequency |
---|---|---|
female | 662 | 49% |
male | 676 | 51% |
children | count | frequency |
---|---|---|
0 | 574 | 43% |
1 | 324 | 24% |
2 | 240 | 18% |
3 | 157 | 12% |
4 | 25 | 2% |
5 | 18 | 1% |
smoker | count | frequency |
---|---|---|
no | 1064 | 80% |
yes | 274 | 20% |
region | count | frequency |
---|---|---|
northeast | 324 | 24% |
northwest | 325 | 24% |
southeast | 364 | 27% |
southwest | 325 | 24% |
bmi_cat | count | frequency |
---|---|---|
Under Weight | 21 | 2% |
Normal Weight | 226 | 17% |
Overweight | 386 | 29% |
Obese | 705 | 53% |
Observations:
- The median number of children is 1, and 75% of individuals have at most 2.
- Obesity dominates the sample (53% of individuals are obese).
- 80% of individuals are non-smokers.
- The dataset is balanced with respect to sex and region.
show_plot(inspect_cat(insurance_data))
Note: individuals with more than 3 children and underweight individuals are negligible in our sample, so we will not take them into account in the comparisons.
variable | count missing |
---|---|
age | 0 |
sex | 0 |
bmi | 0 |
children | 0 |
smoker | 0 |
region | 0 |
charges | 0 |
bmi_cat | 0 |
The dataset contains no missing values!
The variables most strongly correlated with charges are "smoker", "age" and "bmi".
CPCOLS <- c("#5B8888", "#bfd1d9", "#B57779", "#F4EDCA")
insurance_data %>%
dplyr::select(age, bmi, smoker, charges) %>%
GGally::ggpairs(
lower = list(
continuous = GGally::wrap("points", col = "#4c86ad",alpha=0.6),
combo = GGally::wrap("box",
fill = "white", col ="black"
)
),
upper = list(
continuous = GGally::wrap("cor", col = "#4c86ad"),
combo = GGally::wrap("facetdensity",
col =
"black"
)
),
diag = list(
continuous = GGally::wrap(
"barDiag",
fill = "#f5dfb3", col ="black",
bins = 18
),
discrete = GGally::wrap("barDiag", fill = "#f5dfb3", col ="black")
)
)
Observations:
p1 <- ggplot(data = insurance_data, aes(x = charges)) +
geom_histogram(aes(y = after_stat(density)),
col = "white",
bins = 25,
fill = "#90b8c2"
) +
geom_density() +
labs(title = "Distribution des charges") +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
)
p2 <- ggplot(data = insurance_data, aes(y = charges)) +
geom_boxplot(fill = "#90b8c2") +
ggtitle("Charges Boxplot") +
theme(
axis.ticks.x = element_blank(),
axis.text.x = element_blank()
) +
geom_hline(aes(yintercept = mean(charges)),
linetype = "dashed", col =
"red"
)
grid.arrange(p1, p2,
ncol = 2, widths = c(5, 2)
)
The distribution of charges is right-skewed, and we cannot judge the outliers from this box plot. We therefore apply a log transformation to make the distribution more symmetric.
p1 <- ggplot(data = insurance_data, aes(x = log(charges))) +
geom_histogram(aes(y = after_stat(density)),
col = "white",
bins = 25,
fill = "#90b8c2"
) +
geom_density() +
labs(title = "Distribution des charges") +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
)
p2 <- ggplot(data = insurance_data, aes(y = log(charges))) +
geom_boxplot(fill = "#90b8c2") +
ggtitle("Charges Boxplot") +
theme(
axis.text.y = element_blank(),
axis.ticks.x = element_blank(),
axis.text.x = element_blank()
) +
geom_hline(aes(yintercept = mean(log(charges))),
linetype = "dashed", col =
"red"
)
grid.arrange(p1, p2,
ncol = 2, widths = c(5, 2)
)
We can now assume that there are no outliers.
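The effect of the log transform on skewness can be illustrated on simulated right-skewed data (a lognormal stand-in for the charges; the parameters are illustrative, not fitted to the data):

```r
set.seed(1)
x <- rlnorm(5000, meanlog = 9, sdlog = 0.9)  # right-skewed stand-in for charges

# sample skewness: third standardized moment
skewness <- function(v) mean((v - mean(v))^3) / sd(v)^3

skewness(x)       # large and positive: strong right skew
skewness(log(x))  # near zero: roughly symmetric after the log
```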
p1 <- insurance_data %>%
ggplot(aes(x = age, y = charges, col = bmi_cat)) +
geom_point(alpha = 0.6, size = 2.5)
p2 <- insurance_data %>%
ggplot(aes(x = age, y = charges, col = smoker)) +
geom_point(alpha = 0.8,size = 2.5) +
scale_color_manual(values = c("#e09e8f", "#90b8c2")) +
geom_rug() +
geom_smooth() +
geom_smooth(
data = filter(insurance_data, smoker == "yes"),
col = "grey30",
method = lm,
se = FALSE
) +
geom_smooth(
data = filter(insurance_data, smoker == "no"),
col = "grey30",
method = lm,
se = FALSE
)
grid.arrange(p1, p2, nrow = 1)
Charges are related to age by a roughly linear relationship at three levels.
We can also see that, at all three levels, the older the clients, the higher their charges.
p1 <- insurance_data %>%
ggplot(aes(x = bmi, y = charges, col = smoker)) +
geom_point(size = 2.5, alpha = 0.8) +
scale_color_manual(values = c("#6f9fb3", "#79a346")) +
theme(legend.position = c(0.08, 0.85)) +
geom_smooth(
data = filter(insurance_data, smoker == "yes"),
col = "blue"
) +
geom_smooth(
data = filter(insurance_data, smoker == "yes"),
col = "black",
method = lm,
se = FALSE,
lty = "dashed"
) +
geom_vline(aes(xintercept = 30), linetype = "dashed") +
geom_abline(aes(intercept = 12800, slope = 615), col = "red")
p2 <- insurance_data %>%
ggplot(aes(
x = cut(bmi, breaks = 8),
y = charges,
col = smoker
)) +
geom_boxplot() +
scale_color_manual(values = c("#6f9fb3", "#79a346")) +
labs(x = "intervalles de bmi") +
guides(col = FALSE) +
labs(x = element_blank(), y = "\n\n") +
theme(axis.text.x = element_text(
size = 8,
angle = 40,
hjust = 1
)) +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
)
p3 <- insurance_data %>%
filter(smoker == "yes" & bmi_cat != "Under Weight") %>%
ggplot(aes(x = charges, fill = bmi_cat)) +
geom_density(alpha = 0.3) +
scale_color_manual(values = c("#e09e8f", "#90b8c2","#6f9fb3")) +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
) +
theme(legend.position = c(0.85, 0.68)) +
labs(y = "\ndensity\n")
p4 <- insurance_data %>%
filter(smoker == "no" & bmi_cat != "Under Weight") %>%
ggplot(aes(x = charges, fill = bmi_cat)) +
geom_density(alpha = 0.3) +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
) +
guides(fill = FALSE) +
labs(y = "\n\n")
gridExtra::grid.arrange(p1, p2, p3, p4, nrow = 2, heights = c(5, 3))
For non-smokers, obesity has no visible impact on charges.
Smokers, by contrast, form almost a separate point cloud from the non-smokers, one in which charges grow with bmi. This cloud jumps sharply once the obesity threshold (bmi = 30) is crossed: the maximum charges of non-obese smokers coincide with the minimum charges of obese smokers.
p1 <- insurance_data %>%
ggplot(aes(x = bmi, y = charges, col = smoker)) +
geom_point(size = 3, alpha = 0.8) +
scale_color_manual(values = c("#b36b56", "#6f9fb3")) +
geom_density_2d() +
stat_density_2d(aes(fill = after_stat(level)), geom = "polygon", alpha = 0.35) +
labs(fill = "density\n") +
scale_fill_gradient(low = "white", high = "black") +
theme(
legend.position = c(0.07, 0.69),
legend.text = element_blank()
)
ggExtra::ggMarginal(p1, groupColour = TRUE, groupFill = TRUE)
Three density centres can be distinguished:
The first covers all non-smokers and is characterized by low charges and a roughly normal distribution of bmi. The second groups non-obese smokers with relatively high charges,
and the last gathers obese smokers with very high charges.
We create a variable "health" that distinguishes these observed groups:
- "Obese smokers"
- "Non obese smokers"
- "Non smokers"
insurance_data %<>%
mutate(
health = case_when(
smoker == "yes" & bmi_cat == "Obese" ~ "Obese and smoker",
smoker == "yes" & bmi_cat != "Obese" ~ "Non obese and smoker",
smoker == "no" ~ "Non smoker"
)
)
insurance_data$health %<>% as.factor()
p1 <- insurance_data %>%
ggplot(aes(x = charges, fill = health)) +
geom_density(alpha = 0.3) +
theme(legend.position = c(0.81, 0.87)) +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
) +
labs(y = "density\n")
p2 <- insurance_data %>%
ggplot(aes(y = charges, x = age, col = health)) +
geom_point() +
scale_color_manual(values = c("#d6907c","#93c754", "#77a2b5")) +
geom_smooth() +
geom_smooth(
aes(group = health),
col = "black",
method = lm,
se = FALSE,
lty = "dashed"
) +
guides(col = FALSE)
p3 <- insurance_data %>%
ggplot(aes(x = bmi, y = charges, col = health)) +
geom_point(size = 3, alpha = 0.5) +
scale_color_manual(values = c("#d6907c","#93c754", "#77a2b5")) +
stat_ellipse() +
geom_rug() +
guides(col = FALSE)
p4 <- insurance_data %>%
ggplot(aes(x = age, y = charges, col = health)) +
geom_point(
size = 4,
alpha = 0.9,
shape = "O"
) +
scale_color_manual(values = c("#d6907c","#93c754", "#77a2b5")) +
stat_ellipse() +
geom_rug() +
guides(col = FALSE) +
theme(
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
) +
labs(y = "\n\n")
grid.arrange(p1, p2, p3, p4, nrow = 2)
We can now see that obese non-smokers incur lower charges than non-obese smokers; indeed, the latter all sit in the second level of the charges-versus-age point cloud, so the impact of smoking on charges outweighs the impact of obesity.
Non-smokers (obese or not), on the other hand, mostly sit in the first level, but some are scattered into the second, which prompts us to look for the factor behind this difference.
We conclude that obesity splits the smokers into two groups, with the obese group charged significantly more than the non-obese group, but it does not let us draw a conclusion about the non-smokers.
p1 <- ggplot(data = insurance_data, aes(sex, charges)) +
geom_boxplot(fill = c("#cfe6b8", "#8bb054")) +
labs(x = element_blank()) +
ggtitle("Boxplot of Medical Charges by Gender")
p2 <- ggplot(insurance_data, aes(reorder(sex, charges), charges)) +
geom_bar(
fill = c("#cfe6b8", "#8bb054"),
col="black",
position = "dodge",
stat = "summary",
fun = "mean"
) +
labs(x = element_blank(), y = element_blank()) +
ggtitle("Barplot of Medical Charges by Gender")
p1 + p2
We note that medical charges are higher for men than for women.
Could it be that the majority of smokers are men?
p1 <- insurance_data %>%
ggplot(aes(x = smoker, fill = sex)) +
geom_bar(position = "dodge") +
theme(legend.position = c(0.8, 0.88)) +
ggtitle("Count of patients by Gender")+
scale_fill_manual(values = c("#e09e8f", "#90b8c2"))
p2 <- insurance_data %>%
ggplot(aes(sex, charges, fill = smoker)) +
geom_bar(
position = "dodge",
stat = "summary",
fun = "mean"
) +
labs(x = element_blank(), y = "\n\n\ncharges") +
ggtitle("Barplot of Medical Charges by Gender")+
scale_fill_manual(values = c("#e09e8f", "#90b8c2"))
p3 <- insurance_data %>%
ggplot(aes(y = charges, x = smoker, fill = smoker)) +
geom_violin(width = 1.7) +
geom_boxplot(width = 0.28) +
guides(fill = FALSE) +
labs(x = element_blank(), y = element_blank()) +
ggtitle("Boxplot of Medical Charges by Gender") +
scale_fill_manual(values = c("#e09e8f", "#90b8c2")) +
facet_wrap(~sex)
p1 + p2 + p3
We can now confirm our supposition.
Comparing smokers of the two genders, the difference in medical charges between men and women can be explained by two factors:
p1 <- ggplot(insurance_data, aes(charges, col = region)) +
geom_density( fill = "grey", alpha = 0.05) +
ggtitle("Density of Medical Charges per Region") +
theme(legend.position = c(0.82, 0.82)) +
theme(axis.text.y = element_blank()) +
labs(y = "density\n")
p2 <- ggplot(insurance_data, aes(reorder(region, charges), charges)) +
geom_bar(
fill = c("#c2aebb" , "#f5dfb3","#e09e8f","#90b8c2"),
position = "dodge",
stat = "summary",
fun = "mean"
) +
labs(x = "", y = "\ncharges") +
ggtitle("Average of Medical Charges per Region")
p1 + p2
We see that medical charges are fairly evenly distributed across the regions, with a slight increase in the southeast. Let's look for the cause of this difference!
cat("% of each health category by region:\n")
## % of each health category by region:
round(100 * prop.table(table(
insurance_data$health,
insurance_data$region
), margin = 2), digits = 2)
##
## northeast northwest southeast southwest
## Non obese and smoker 11.73 10.77 9.07 7.38
## Non smoker 79.32 82.15 75.00 82.15
## Obese and smoker 8.95 7.08 15.93 10.46
cat("\n\n% of smokers by region:\n")
##
##
## % of smokers by region:
round(100 * prop.table(table(
insurance_data$smoker,
insurance_data$region
), margin = 2), digits = 2)
##
## northeast northwest southeast southwest
## no 79.32 82.15 75.00 82.15
## yes 20.68 17.85 25.00 17.85
cat("\n\n% of each obesity category by region:\n")
##
##
## % of each obesity category by region:
round(100 * prop.table(
table(insurance_data$bmi_cat,
insurance_data$region),
margin = 2
), digits = 2)
##
## northeast northwest southeast southwest
## Under Weight 3.09 2.15 0.00 1.23
## Normal Weight 22.53 19.38 11.26 15.08
## Overweight 30.25 32.92 21.98 31.08
## Obese 44.14 45.54 66.76 52.62
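As a reminder, `prop.table(..., margin = 2)` normalizes each column, so each region's percentages sum to 100; a toy table (counts invented) makes this explicit:

```r
# Invented counts: 2 smoker levels x 2 regions
m <- matrix(c(10, 30, 20, 40), nrow = 2,
            dimnames = list(smoker = c("no", "yes"), region = c("east", "west")))

# margin = 2 divides each cell by its column total,
# so every column sums to 100%
round(100 * prop.table(m, margin = 2), 2)
# east: 25 / 75, west: 33.33 / 66.67
```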
ggpubr::ggballoonplot(as.data.frame(table(dplyr::select(
insurance_data, region, bmi_cat
))),
fill = "value") +
scale_fill_viridis_c(option = "D")
library(FactoMineR)
library(factoextra)
res.ca <-
CA(table(dplyr::select(insurance_data, region, health)), graph = FALSE)
fviz_ca_biplot(res.ca, repel = TRUE)
From the contingency tables and the plot above, the high charges in the southeast can be explained by a larger proportion of obese smokers than in the other regions.
The CA biplot also visualizes the dependence between the levels of "health" and "region": on the first axis, which carries 83.5% of the information, the "southeast" region stands out through a marked association with the "Obese and smoker" class.
p1 <- ggplot(filter(insurance_data, children %in% c(0:3))) +
geom_density_ridges(
aes(
x = charges,
y = children,
fill = children,
height = after_stat(density)
),
alpha = 0.6,
stat = "density",
scale = 2.25
) +
scale_fill_manual(values = c("#f5dfb3","#e09e8f","#b36b56","#8bbfaf")) +
scale_y_discrete(expand = expansion(mult = c(0.07, .7))) +
coord_cartesian(clip = "off") +
theme(
legend.position = c(0.88, 0.82),
axis.text.y = element_blank(),
axis.ticks.y = element_blank()
) +
labs(y = "density")
p2 <- insurance_data %>%
ggplot(aes(reorder(children, charges), charges)) +
geom_bar(
position = "dodge",
stat = "summary",
fun = "mean",
fill = c("#f5dfb3","#e09e8f","#b36b56","#8bbfaf")
) +
scale_x_discrete(limits = factor(c(0:3))) +
labs(x = "children", y = "\n\ncharges\n",title="Average of medical charges by number of children")
p1 + p2
Charges are higher for parents.
The charges of non-parents show a distinctive density.
We visualize this difference with a new variable "parent", with two levels, "yes" and "no".
insurance_data$parent <-
as.factor(ifelse(insurance_data$children == 0, "no", "yes"))
p1 <- insurance_data %>%
ggplot(aes(y = charges, x = parent, fill = parent)) +
geom_violin(width = 0.8,alpha=0.7) +
geom_boxplot(col = "grey30", alpha = 0.6, width = 0.13) +
theme(legend.position = c(0.9, 0.9)) +
labs(x = element_blank(), y = element_blank()) +
scale_fill_manual(values = c("#a7c76c","#d19bbb"))
p <- insurance_data %>%
ggplot() +
geom_point(aes(x = age, y = charges, col = parent),
size = 1,
alpha = 0.8) +
scale_color_manual(values=c("#79a346","#d19bbb"))+
theme(axis.text.y = element_blank(),
axis.ticks.y = element_blank())
p2 <- ggExtra::ggMarginal(p + theme(legend.position = "none"),
groupColour = TRUE,
groupFill = TRUE
)
grid.arrange(p1, p2,
ncol = 2, widths = c(2, 3)
)
The distinction between the charges of parents and non-parents is clearest in the first level.
Most individuals aged 30 to 50 are parents, while the youngest and the oldest are not parents of children covered by the insurance.
ad <- 0.9
p1 <- insurance_data %>%
filter(parent == "yes") %>%
ggplot(aes(x = age, y = charges, col = health)) +
scale_color_manual(values = c("#79a346", "#245373","#964a74"))+
geom_point(size = 0.8, alpha = 0.7) +
stat_density_2d(
aes(fill = after_stat(level)),
geom = "polygon",
alpha = 0.5,
adjust = ad
) +
labs(fill = "density\n") +
scale_fill_gradient(low = "white", high = "black") +
guides(fill = FALSE, col = FALSE)
p3 <- ggExtra::ggMarginal(p1, groupColour = TRUE, groupFill = TRUE)
p2 <- insurance_data %>%
filter(parent == "no") %>%
ggplot(aes(x = age, y = charges, col = health)) +
geom_point(size = 0.8, alpha = 0.7) +
scale_color_manual(values = c("#79a346", "#245373","#964a74")) +
stat_density_2d(
aes(fill = after_stat(level)),
geom = "polygon",
alpha = 0.5,
adjust = ad
) +
labs(y = "\n") +
scale_fill_gradient(low = "white", high = "black") +
guides(fill = FALSE, col = guide_legend(order=1))+
theme(legend.position = "left")
p4 <- ggExtra::ggMarginal(p2, groupColour = TRUE, groupFill = TRUE)
grid.arrange(p3, p4, nrow = 1, widths = c(6, 9))
Because of the relationship between age and charges:
Parents' charges are centred on the mean of each level, whereas non-parents combine both high charges (from the older ones) and low charges (from the younger ones).
This difference is less noticeable when charges are high, because the individual's own charges dominate and the children's contribution becomes negligible.
We start with a simple linear model using the two variables age and health, centring the age variable so that the intercept is interpretable:
insurance_data %<>% mutate(centred_age = age - 18)
insurance_data$health <- relevel(insurance_data$health, ref = "Non smoker")
mod.basic = lm(charges ~ centred_age + health,
data = insurance_data,
x = TRUE,
y = TRUE)
summary(mod.basic)
##
## Call:
## lm(formula = charges ~ centred_age + health, data = insurance_data,
## x = TRUE, y = TRUE)
##
## Residuals:
## Min 1Q Median 3Q Max
## -5585.7 -1953.5 -1324.8 -394.4 24493.8
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 2693.654 234.102 11.51 <2e-16 ***
## centred_age 268.437 8.816 30.45 <2e-16 ***
## healthNon obese and smoker 13343.999 420.794 31.71 <2e-16 ***
## healthObese and smoker 33334.018 401.955 82.93 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4527 on 1334 degrees of freedom
## Multiple R-squared: 0.8606, Adjusted R-squared: 0.8603
## F-statistic: 2745 on 3 and 1334 DF, p-value: < 2.2e-16
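The role of centring can be checked on simulated data (coefficients invented): with `I(age - 18)` as the predictor, the fitted intercept is exactly the model's prediction at age 18, which is what makes the 2693.65 above interpretable as the expected charges of an 18-year-old non-smoker.

```r
set.seed(42)
age <- sample(18:64, 300, replace = TRUE)
y   <- 2500 + 270 * (age - 18) + rnorm(300, sd = 500)  # invented coefficients

fit <- lm(y ~ I(age - 18))

# the intercept equals the fitted value at age = 18
coef(fit)[["(Intercept)"]]
predict(fit, newdata = data.frame(age = 18))  # same value
```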
mod.basic %>%
augment() %>%
ggplot(aes(x = centred_age, y = charges)) +
geom_point(aes(col = health)) +
scale_color_manual(values = c("#79a346", "#245373","#964a74")) +
geom_line(aes(y = .fitted)) +
geom_segment(aes(xend = centred_age, yend = .fitted), col = "grey70") +
facet_grid( ~ health)
mod.basic %>%
augment() %>%
ggplot(aes(x = .fitted, y = charges)) +
geom_point(shape = 1, aes(col = health)) +
scale_color_manual(values = c("#79a346", "#245373","#964a74")) +
geom_abline(slope = 1, lty = "dashed")
mod.basic %>%
glance() %>%
dplyr::select(r.squared,
adj.r.squared,
AIC,
BIC,
statistic,
p.value,
df.residual,
nobs) %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
r.squared | adj.r.squared | AIC | BIC | statistic | p.value | df.residual | nobs |
---|---|---|---|---|---|---|---|
0.8605842 | 0.8602707 | 26329.01 | 26355 | 2744.834 | 0 | 1334 | 1338 |
plot(mod.basic, which=1)
We can assume a linear relationship between the predictors and the target variable.
plot(mod.basic, which=3)
We see that the variability of the residuals as a function of the fitted values has a quadratic shape for one group of individuals, indicating heteroscedasticity (non-constant error variance).
This leads us to transform the age variable.
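The benefit of a squared-age term can be sketched on simulated data (slope and noise level invented): when the true effect is quadratic, the model fitted with `I(age^2)` attains a clearly lower AIC than a model linear in age.

```r
set.seed(7)
age <- sample(18:64, 400, replace = TRUE)
y   <- 3.3 * age^2 + rnorm(400, sd = 500)  # truly quadratic age effect

AIC(lm(y ~ age))       # misspecified linear term
AIC(lm(y ~ I(age^2)))  # correct quadratic term: lower AIC
```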
mod.age2 = lm(charges ~ I(age ^ 2) + health, data = insurance_data)
summary(mod.age2)
##
## Call:
## lm(formula = charges ~ I(age^2) + health, data = insurance_data)
##
## Residuals:
## Min 1Q Median 3Q Max
## -4980.7 -1975.6 -1425.1 -264.2 23826.2
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 2.553e+03 2.368e+02 10.79 <2e-16 ***
## I(age^2) 3.362e+00 1.098e-01 30.62 <2e-16 ***
## healthNon obese and smoker 1.339e+04 4.199e+02 31.90 <2e-16 ***
## healthObese and smoker 3.331e+04 4.010e+02 83.06 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4516 on 1334 degrees of freedom
## Multiple R-squared: 0.8612, Adjusted R-squared: 0.8609
## F-statistic: 2760 on 3 and 1334 DF, p-value: < 2.2e-16
Performance:
mod.age2 %>%
glance() %>%
dplyr::select(r.squared,
adj.r.squared,
AIC,
BIC,
statistic,
p.value,
df.residual,
nobs) %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
r.squared | adj.r.squared | AIC | BIC | statistic | p.value | df.residual | nobs |
---|---|---|---|---|---|---|---|
0.8612288 | 0.8609168 | 26322.81 | 26348.8 | 2759.65 | 0 | 1334 | 1338 |
plot(mod.age2, which = 1)
plot(mod.age2, which = 3)
We can assume that the first two assumptions hold for this second model.
We now try to improve accuracy further by adding more explanatory variables:
mod.extend = lm(charges ~ I(age ^ 2) + health + smoker:bmi + children + region +
sex,
data = insurance_data)
summary(mod.extend)
##
## Call:
## lm(formula = charges ~ I(age^2) + health + smoker:bmi + children +
## region + sex, data = insurance_data)
##
## Residuals:
## Min 1Q Median 3Q Max
## -3271 -1496 -1217 -849 24210
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 2189.2344 747.2089 2.930 0.003449 **
## I(age^2) 3.3456 0.1068 31.322 < 2e-16 ***
## healthNon obese and smoker 1461.1728 1960.6927 0.745 0.456263
## healthObese and smoker 16616.0387 2635.5680 6.305 3.93e-10 ***
## children1 886.3802 302.7192 2.928 0.003469 **
## children2 1812.5903 335.3965 5.404 7.71e-08 ***
## children3 1529.6219 393.1487 3.891 0.000105 ***
## children4 4098.2667 890.5092 4.602 4.58e-06 ***
## children5 2174.2555 1046.7760 2.077 0.037985 *
## regionnorthwest -394.3059 342.6695 -1.151 0.250068
## regionsoutheast -895.1742 344.7422 -2.597 0.009518 **
## regionsouthwest -1189.3192 343.6165 -3.461 0.000555 ***
## sexmale -522.6149 239.4042 -2.183 0.029213 *
## smokerno:bmi 14.6839 23.0145 0.638 0.523564
## smokeryes:bmi 486.1988 71.3217 6.817 1.41e-11 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 4352 on 1323 degrees of freedom
## Multiple R-squared: 0.8722, Adjusted R-squared: 0.8708
## F-statistic: 644.8 on 14 and 1323 DF, p-value: < 2.2e-16
Adding these predictors raises the adjusted R-squared to 87.08%.
We test whether the sex and region variables contribute to model performance.
H0: adding the variable brings no improvement.
anova(lm(charges ~ I(age ^ 2) + health + smoker:bmi + children + sex , data = insurance_data) ,
mod.extend)
## Analysis of Variance Table
##
## Model 1: charges ~ I(age^2) + health + smoker:bmi + children + sex
## Model 2: charges ~ I(age^2) + health + smoker:bmi + children + region +
## sex
## Res.Df RSS Df Sum of Sq F Pr(>F)
## 1 1326 2.5326e+10
## 2 1323 2.5062e+10 3 263504216 4.6366 0.003128 **
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
anova(
lm(charges ~ I(age ^ 2) + health + smoker:bmi + children + region , data = insurance_data) ,
mod.extend
)
## Analysis of Variance Table
##
## Model 1: charges ~ I(age^2) + health + smoker:bmi + children + region
## Model 2: charges ~ I(age^2) + health + smoker:bmi + children + region +
## sex
## Res.Df RSS Df Sum of Sq F Pr(>F)
## 1 1324 2.5153e+10
## 2 1323 2.5062e+10 1 90273982 4.7654 0.02921 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
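As a sanity check on what these tables report, the partial F statistic for nested models can be recomputed from the residual sums of squares on toy data (all numbers invented):

```r
set.seed(3)
x1 <- rnorm(150); x2 <- rnorm(150)
y  <- 2 * x1 + 0.5 * x2 + rnorm(150)  # invented toy data

reduced <- lm(y ~ x1)
full    <- lm(y ~ x1 + x2)
a <- anova(reduced, full)

# the reported F is ((RSS_reduced - RSS_full) / q) / (RSS_full / df_full),
# with q the number of restrictions (here 1)
rss_r    <- sum(residuals(reduced)^2)
rss_f    <- sum(residuals(full)^2)
F_manual <- ((rss_r - rss_f) / 1) / (rss_f / df.residual(full))

all.equal(a$F[2], F_manual)  # TRUE
```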
At the 5% significance level, we reject H0 in both cases.
We therefore keep both variables in the model.
Linearity and normality of the residuals:
par(mfrow = c(1, 3))
for (i in c(1, 3, 2)) {
plot(mod.extend, which = i)
}
The distribution of the residuals is not normal.
The linearity assumption holds and the variance looks constant, but we run formal tests to make sure.
Heteroscedasticity:
library(lmtest)
bptest(mod.extend) # test for heteroscedasticity that is linear in the predictors
##
## studentized Breusch-Pagan test
##
## data: mod.extend
## BP = 16.844, df = 14, p-value = 0.2646
ncvTest(mod.extend) # test for heteroscedasticity as a function of the fitted values
## Non-constant Variance Score Test
## Variance formula: ~ fitted.values
## Chisquare = 8.567156, Df = 1, p = 0.0034228
The Breusch-Pagan test shows no sign of heteroscedasticity linear in the predictors (p = 0.26); the score test against the fitted values, however, is significant (p = 0.003), so some non-constant variance remains and the standard errors should be read with that caveat.
par(mfrow = c(2, 2))
for (i in c(4:6)) {
plot(mod.extend, which = i)
}
influencePlot(mod.extend, main = "Influence Plot")
## StudRes Hat CookD
## 322 3.70355755 0.043793647 4.148127e-02
## 517 5.64758608 0.007301576 1.528300e-02
## 1013 4.32167210 0.046084448 5.935982e-02
## 1048 -0.22975456 0.086606152 3.339172e-04
## 1086 0.09609113 0.077335829 5.163424e-05
## 1301 5.54943002 0.018034681 3.687613e-02
There are no highly influential points in the model: the points with large residuals are not extreme with respect to the predictors, so their low hat-values keep them from pulling the regression line, while the points that are extreme in the predictors have small residuals.
The points with the largest Cook's distances (an influence measure combining leverage and residual) owe it mainly to their large residuals (studentized residuals above 3), so the error term contributes most to their influence.
perf=mod.extend %>%
glance() %>%
dplyr::select(r.squared,
adj.r.squared,
AIC,
BIC,
statistic,
p.value,
df.residual,
nobs)
perf%>% knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
r.squared | adj.r.squared | AIC | BIC | statistic | p.value | df.residual | nobs |
---|---|---|---|---|---|---|---|
0.872179 | 0.8708264 | 26234.83 | 26318.01 | 644.8153 | 0 | 1323 | 1338 |
library(relaimpo)
mod.extend.shapley <-
calc.relimp(lm(charges ~ I(age ^ 2) + health + bmi + children + region +
sex,
data = insurance_data),
type = "lmg")
mod.extend.shapley$lmg %>%
sort(T) %>%
data.frame() -> temp
names(temp) <- c("Shaply_value")
p <- temp %>%
cbind(name = rownames(.)) %>%
ggplot() +
geom_col(aes(reorder(name, -Shaply_value), Shaply_value),
fill = c("#4a9687","#c2aebb","#f5dfb3","#e09e8f","#cfe6b8","#f2d696"), col = "black") +
ggtitle("Relative Importance of Predictors") +
theme(plot.title = element_text(hjust = 0.5)) +
labs(x = "Predictor Labels", y = "Shapley Value Regression")
grid.arrange(p ,tableGrob(round(temp,4)), ncol = 2, widths = c(4, 2))
To derive the importance of the predictors in this model, we use a statistical method called Shapley value regression, a solution that comes from game theory. Its goal is to apportion importance fairly among the predictors of a regression analysis. Given n independent variables (IVs), we fit every combination of linear regression models of the dependent variable (DV) on these IVs and record each model's R-squared. The importance of each IV is then its average marginal contribution to the total R-squared, obtained by decomposing the total R-squared into proportional marginal contributions.
Suppose we have 2 IVs, A and B, and a dependent variable Y. We fit 3 models: 1) Y ~ A, 2) Y ~ B, 3) Y ~ A + B, each with its own R-squared.
To obtain the Shapley value of A, we decompose the R-squared of the third model and compute the marginal contribution of A:
Shapley value (A) = \(\frac{(R^2(AB) - R^2(B)) + R^2(A)}{2}\)
We used the calc.relimp() function from the relaimpo package to compute the Shapley value of each of our predictors.
The Shapley value of each attribute shows its marginal contribution to the model's overall R-squared (0.8722). We can therefore conclude that, of the total 87.22% of variance explained by our model, almost 75% is due to the "health" attribute.
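The two-predictor decomposition can be verified numerically on simulated data (unrelated to the insurance dataset); by construction, the two Shapley values sum exactly to the R-squared of the full model:

```r
set.seed(5)
n <- 500
A <- rnorm(n)
B <- 0.5 * A + rnorm(n)  # correlated predictors, so the split is non-trivial
Y <- A + B + rnorm(n)

r2 <- function(f) summary(lm(f))$r.squared

shapley_A <- ((r2(Y ~ A + B) - r2(Y ~ B)) + r2(Y ~ A)) / 2
shapley_B <- ((r2(Y ~ A + B) - r2(Y ~ A)) + r2(Y ~ B)) / 2

shapley_A + shapley_B  # equals r2(Y ~ A + B), by construction
```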
library(mgcv)
# Build the model
mod.gam <- gam(charges ~ s(age) + health + s(bmi) + children + region +
sex,
data = insurance_data)
summary(mod.gam)
##
## Family: gaussian
## Link function: identity
##
## Formula:
## charges ~ s(age) + health + s(bmi) + children + region + sex
##
## Parametric coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 8335.7 314.7 26.486 < 2e-16 ***
## healthNon obese and smoker 13603.2 430.7 31.584 < 2e-16 ***
## healthObese and smoker 33216.6 411.2 80.780 < 2e-16 ***
## children1 970.9 320.6 3.028 0.002508 **
## children2 1844.7 355.5 5.189 2.44e-07 ***
## children3 1593.2 408.7 3.899 0.000102 ***
## children4 4093.3 908.3 4.507 7.17e-06 ***
## children5 2090.6 1067.6 1.958 0.050414 .
## regionnorthwest -327.5 347.8 -0.942 0.346553
## regionsoutheast -786.3 349.2 -2.252 0.024512 *
## regionsouthwest -1162.8 348.6 -3.335 0.000876 ***
## sexmale -481.9 242.5 -1.987 0.047130 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Approximate significance of smooth terms:
## edf Ref.df F p-value
## s(age) 3.083 3.832 246.331 <2e-16 ***
## s(bmi) 3.322 4.206 2.001 0.085 .
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## R-sq.(adj) = 0.867 Deviance explained = 86.9%
## GCV = 1.9721e+07 Scale est. = 1.945e+07 n = 1338
mod.gam %>%
glance() %>%
mutate(adj.r.squared = 0.867) %>%
dplyr::select(adj.r.squared, AIC, BIC, nobs) %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
adj.r.squared | AIC | BIC | nobs |
---|---|---|---|
0.867 | 26273.5 | 26374.39 | 1338 |
The generalized additive model shows no improvement over the previous model, so we prefer the linear model for its simplicity and interpretability.
To evaluate the knn model, we split our dataset into two parts: 80% for the training set and 20% for the test set.
set.seed(12345)
insurance_data_split <-
initial_split(insurance_data, prop = .8, strata = "charges")
insurance_train <- training(insurance_data_split)
insurance_test <- testing(insurance_data_split)
Using the parsnip package, we instantiate a knn regression model and train it on the training set.
library(parsnip)
library(kknn)
model_knn <- nearest_neighbor(mode = "regression") %>%
set_engine("kknn") %>%
fit(charges ~ age + bmi + children + health +
region, data = insurance_train)
We then use this model to make predictions on the test set in order to evaluate its performance.
model_knn_pred <- model_knn %>%
predict(new_data = insurance_test) %>%
bind_cols(insurance_test %>% dplyr::select(charges))
modelPerform = data.frame(
RMSE = RMSE(model_knn_pred$.pred, model_knn_pred$charges),
R2 = R2(model_knn_pred$.pred, model_knn_pred$charges)
)
modelPerform %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
RMSE | R2 |
---|---|
5032.466 | 0.8481833 |
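The RMSE and R2 values above come from caret's RMSE() and R2() helpers; as a sanity check, both can be computed by hand (a minimal sketch on hypothetical prediction and observation vectors):

```r
# Hypothetical predictions and observed values
pred <- c(10, 12, 15, 9)
obs  <- c(11, 14, 13, 8)

rmse <- sqrt(mean((pred - obs)^2))  # root mean squared error
r2   <- cor(pred, obs)^2            # caret's R2() defaults to the squared correlation
```

RMSE is in the same units as charges, which makes it directly interpretable as a typical prediction error in dollars.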
To improve performance further, we try to find the optimal k.
To automate preprocessing, we first create a recipe that standardizes the numeric variables age and bmi, then combine it with the knn model in a workflow.
#creating recipe
ins_recipe <-
recipe(charges ~ age + bmi + children + health + region, data = insurance_train) %>%
step_scale(age, bmi) %>%
step_center(age, bmi)
#model specification
ins_spec <-
nearest_neighbor(weight_func = "rectangular", neighbors = tune()) %>%
set_engine("kknn") %>%
set_mode("regression")
#workflow(model+recipe)
ins_wf <- workflow() %>%
add_recipe(ins_recipe) %>%
add_model(ins_spec)
ins_wf
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: nearest_neighbor()
##
## -- Preprocessor ----------------------------------------------------------------
## 2 Recipe Steps
##
## * step_scale()
## * step_center()
##
## -- Model -----------------------------------------------------------------------
## K-Nearest Neighbor Model Specification (regression)
##
## Main Arguments:
## neighbors = tune()
## weight_func = rectangular
##
## Computational engine: kknn
To find the optimal k for the model, we use 5-fold cross-validation: the data are split into five parts, and at each iteration one part is held out for validation while the others are used for training; performance is then the average of the performances over the folds.
Here we search for the optimal k in the interval [1, 100], measuring performance with cross-validation to reduce the error that a single train/validation split could introduce.
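The resampling loop that tune_grid() runs for each candidate k can be sketched by hand. This is a simplified illustration on synthetic data (the columns x, z, y are hypothetical stand-ins for our predictors and target), not the tuned pipeline itself:

```r
library(kknn)

set.seed(123456)
# Synthetic stand-in for the training data
d <- data.frame(x = runif(200), z = runif(200))
d$y <- 3 * d$x + rnorm(200, sd = 0.1)

k_folds <- 5
# Randomly assign each row to one of five folds
fold_id <- sample(rep(1:k_folds, length.out = nrow(d)))

rmse_per_fold <- sapply(1:k_folds, function(f) {
  # Hold out fold f for validation, train on the rest
  fit <- kknn(y ~ x + z,
              train = d[fold_id != f, ], test = d[fold_id == f, ],
              k = 21, kernel = "rectangular")
  sqrt(mean((d$y[fold_id == f] - fitted(fit))^2))
})
mean(rmse_per_fold)  # cross-validated RMSE for one candidate k
```

tune_grid() repeats exactly this loop for every candidate value of neighbors and records the averaged metrics.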
set.seed(123456)
ins_vfold <- vfold_cv(insurance_train, v = 5, strata = charges)
gridvals <- tibble(neighbors = seq(1, 100))
ins_results <- ins_wf %>%
tune_grid(resamples = ins_vfold, grid = gridvals) %>%
collect_metrics()
ins_results %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = T, position = "left") %>%
kableExtra::kable_paper() %>%
kableExtra::scroll_box( height = "300px")
neighbors | .metric | .estimator | mean | n | std_err | .config |
---|---|---|---|---|---|---|
1 | rmse | standard | 6178.9140580 | 5 | 173.3195772 | Preprocessor1_Model001 |
1 | rsq | standard | 0.7523178 | 5 | 0.0152306 | Preprocessor1_Model001 |
2 | rmse | standard | 5287.3902576 | 5 | 182.8131372 | Preprocessor1_Model002 |
2 | rsq | standard | 0.8068960 | 5 | 0.0170376 | Preprocessor1_Model002 |
3 | rmse | standard | 5100.9016560 | 5 | 202.0380177 | Preprocessor1_Model003 |
3 | rsq | standard | 0.8185064 | 5 | 0.0169442 | Preprocessor1_Model003 |
4 | rmse | standard | 4963.0848913 | 5 | 227.1004810 | Preprocessor1_Model004 |
4 | rsq | standard | 0.8269390 | 5 | 0.0187191 | Preprocessor1_Model004 |
5 | rmse | standard | 4890.3112905 | 5 | 253.3037270 | Preprocessor1_Model005 |
5 | rsq | standard | 0.8315022 | 5 | 0.0196675 | Preprocessor1_Model005 |
6 | rmse | standard | 4851.6850923 | 5 | 255.4585673 | Preprocessor1_Model006 |
6 | rsq | standard | 0.8336937 | 5 | 0.0195599 | Preprocessor1_Model006 |
7 | rmse | standard | 4775.7779870 | 5 | 252.4518214 | Preprocessor1_Model007 |
7 | rsq | standard | 0.8392938 | 5 | 0.0185528 | Preprocessor1_Model007 |
8 | rmse | standard | 4773.4450843 | 5 | 255.8713731 | Preprocessor1_Model008 |
8 | rsq | standard | 0.8396510 | 5 | 0.0185616 | Preprocessor1_Model008 |
9 | rmse | standard | 4733.5640708 | 5 | 256.0109169 | Preprocessor1_Model009 |
9 | rsq | standard | 0.8427674 | 5 | 0.0180070 | Preprocessor1_Model009 |
10 | rmse | standard | 4715.0323446 | 5 | 263.7646139 | Preprocessor1_Model010 |
10 | rsq | standard | 0.8437694 | 5 | 0.0184588 | Preprocessor1_Model010 |
11 | rmse | standard | 4693.7602162 | 5 | 264.6281078 | Preprocessor1_Model011 |
11 | rsq | standard | 0.8450957 | 5 | 0.0183661 | Preprocessor1_Model011 |
12 | rmse | standard | 4675.7965036 | 5 | 266.1327002 | Preprocessor1_Model012 |
12 | rsq | standard | 0.8460284 | 5 | 0.0184862 | Preprocessor1_Model012 |
13 | rmse | standard | 4658.2517213 | 5 | 264.0859288 | Preprocessor1_Model013 |
13 | rsq | standard | 0.8470701 | 5 | 0.0183785 | Preprocessor1_Model013 |
14 | rmse | standard | 4653.9054267 | 5 | 267.4354491 | Preprocessor1_Model014 |
14 | rsq | standard | 0.8472995 | 5 | 0.0184257 | Preprocessor1_Model014 |
15 | rmse | standard | 4649.7572197 | 5 | 274.8265877 | Preprocessor1_Model015 |
15 | rsq | standard | 0.8474079 | 5 | 0.0189709 | Preprocessor1_Model015 |
16 | rmse | standard | 4663.8376805 | 5 | 270.9116049 | Preprocessor1_Model016 |
16 | rsq | standard | 0.8465355 | 5 | 0.0188114 | Preprocessor1_Model016 |
17 | rmse | standard | 4646.5558310 | 5 | 265.4033259 | Preprocessor1_Model017 |
17 | rsq | standard | 0.8476093 | 5 | 0.0184269 | Preprocessor1_Model017 |
18 | rmse | standard | 4648.9109798 | 5 | 258.6419467 | Preprocessor1_Model018 |
18 | rsq | standard | 0.8476848 | 5 | 0.0178239 | Preprocessor1_Model018 |
19 | rmse | standard | 4640.7203726 | 5 | 260.4874734 | Preprocessor1_Model019 |
19 | rsq | standard | 0.8482083 | 5 | 0.0178772 | Preprocessor1_Model019 |
20 | rmse | standard | 4629.8374928 | 5 | 263.9582032 | Preprocessor1_Model020 |
20 | rsq | standard | 0.8490991 | 5 | 0.0178641 | Preprocessor1_Model020 |
21 | rmse | standard | 4629.1660454 | 5 | 260.0895097 | Preprocessor1_Model021 |
21 | rsq | standard | 0.8491335 | 5 | 0.0176451 | Preprocessor1_Model021 |
22 | rmse | standard | 4634.4062304 | 5 | 262.3567633 | Preprocessor1_Model022 |
22 | rsq | standard | 0.8489441 | 5 | 0.0177543 | Preprocessor1_Model022 |
23 | rmse | standard | 4647.4361189 | 5 | 261.3037340 | Preprocessor1_Model023 |
23 | rsq | standard | 0.8481788 | 5 | 0.0177749 | Preprocessor1_Model023 |
24 | rmse | standard | 4655.3149210 | 5 | 258.6656146 | Preprocessor1_Model024 |
24 | rsq | standard | 0.8476871 | 5 | 0.0176269 | Preprocessor1_Model024 |
25 | rmse | standard | 4654.8456030 | 5 | 257.9780451 | Preprocessor1_Model025 |
25 | rsq | standard | 0.8477278 | 5 | 0.0176100 | Preprocessor1_Model025 |
26 | rmse | standard | 4669.6677147 | 5 | 262.0185012 | Preprocessor1_Model026 |
26 | rsq | standard | 0.8467377 | 5 | 0.0179200 | Preprocessor1_Model026 |
27 | rmse | standard | 4675.1333349 | 5 | 253.9520483 | Preprocessor1_Model027 |
27 | rsq | standard | 0.8467449 | 5 | 0.0173183 | Preprocessor1_Model027 |
28 | rmse | standard | 4686.8939256 | 5 | 251.7401776 | Preprocessor1_Model028 |
28 | rsq | standard | 0.8461133 | 5 | 0.0171720 | Preprocessor1_Model028 |
29 | rmse | standard | 4697.4520238 | 5 | 251.6249278 | Preprocessor1_Model029 |
29 | rsq | standard | 0.8456957 | 5 | 0.0170900 | Preprocessor1_Model029 |
30 | rmse | standard | 4697.7346408 | 5 | 247.4222881 | Preprocessor1_Model030 |
30 | rsq | standard | 0.8460997 | 5 | 0.0168078 | Preprocessor1_Model030 |
31 | rmse | standard | 4707.8854025 | 5 | 248.8065567 | Preprocessor1_Model031 |
31 | rsq | standard | 0.8457563 | 5 | 0.0166483 | Preprocessor1_Model031 |
32 | rmse | standard | 4725.2919175 | 5 | 248.3321147 | Preprocessor1_Model032 |
32 | rsq | standard | 0.8451747 | 5 | 0.0165526 | Preprocessor1_Model032 |
33 | rmse | standard | 4735.5588714 | 5 | 252.6255563 | Preprocessor1_Model033 |
33 | rsq | standard | 0.8450080 | 5 | 0.0167783 | Preprocessor1_Model033 |
34 | rmse | standard | 4746.0761373 | 5 | 250.6010354 | Preprocessor1_Model034 |
34 | rsq | standard | 0.8450294 | 5 | 0.0165461 | Preprocessor1_Model034 |
35 | rmse | standard | 4762.7821750 | 5 | 252.4975171 | Preprocessor1_Model035 |
35 | rsq | standard | 0.8446822 | 5 | 0.0167139 | Preprocessor1_Model035 |
36 | rmse | standard | 4766.5267641 | 5 | 255.1530408 | Preprocessor1_Model036 |
36 | rsq | standard | 0.8452888 | 5 | 0.0168055 | Preprocessor1_Model036 |
37 | rmse | standard | 4782.1438923 | 5 | 256.0248366 | Preprocessor1_Model037 |
37 | rsq | standard | 0.8448535 | 5 | 0.0167685 | Preprocessor1_Model037 |
38 | rmse | standard | 4806.4956334 | 5 | 254.7956850 | Preprocessor1_Model038 |
38 | rsq | standard | 0.8440514 | 5 | 0.0167226 | Preprocessor1_Model038 |
39 | rmse | standard | 4818.2964146 | 5 | 257.1599681 | Preprocessor1_Model039 |
39 | rsq | standard | 0.8439850 | 5 | 0.0167531 | Preprocessor1_Model039 |
40 | rmse | standard | 4820.8396510 | 5 | 254.8822517 | Preprocessor1_Model040 |
40 | rsq | standard | 0.8447060 | 5 | 0.0165704 | Preprocessor1_Model040 |
41 | rmse | standard | 4840.8191847 | 5 | 256.0101246 | Preprocessor1_Model041 |
41 | rsq | standard | 0.8439965 | 5 | 0.0166757 | Preprocessor1_Model041 |
42 | rmse | standard | 4852.4308016 | 5 | 254.5219178 | Preprocessor1_Model042 |
42 | rsq | standard | 0.8438103 | 5 | 0.0166872 | Preprocessor1_Model042 |
43 | rmse | standard | 4867.6603943 | 5 | 254.1660848 | Preprocessor1_Model043 |
43 | rsq | standard | 0.8441141 | 5 | 0.0163858 | Preprocessor1_Model043 |
44 | rmse | standard | 4880.2858251 | 5 | 254.0490051 | Preprocessor1_Model044 |
44 | rsq | standard | 0.8442817 | 5 | 0.0163405 | Preprocessor1_Model044 |
45 | rmse | standard | 4896.9124752 | 5 | 256.7190085 | Preprocessor1_Model045 |
45 | rsq | standard | 0.8441670 | 5 | 0.0165102 | Preprocessor1_Model045 |
46 | rmse | standard | 4907.1107505 | 5 | 251.0990212 | Preprocessor1_Model046 |
46 | rsq | standard | 0.8442302 | 5 | 0.0161694 | Preprocessor1_Model046 |
47 | rmse | standard | 4934.1177870 | 5 | 250.3514229 | Preprocessor1_Model047 |
47 | rsq | standard | 0.8434454 | 5 | 0.0162412 | Preprocessor1_Model047 |
48 | rmse | standard | 4950.5952668 | 5 | 250.8397584 | Preprocessor1_Model048 |
48 | rsq | standard | 0.8431428 | 5 | 0.0163417 | Preprocessor1_Model048 |
49 | rmse | standard | 4971.3220377 | 5 | 252.7660356 | Preprocessor1_Model049 |
49 | rsq | standard | 0.8427938 | 5 | 0.0164434 | Preprocessor1_Model049 |
50 | rmse | standard | 4991.3165169 | 5 | 254.1472657 | Preprocessor1_Model050 |
50 | rsq | standard | 0.8424261 | 5 | 0.0164899 | Preprocessor1_Model050 |
51 | rmse | standard | 5010.3814183 | 5 | 258.9393462 | Preprocessor1_Model051 |
51 | rsq | standard | 0.8419906 | 5 | 0.0168380 | Preprocessor1_Model051 |
52 | rmse | standard | 5032.7688162 | 5 | 256.3764770 | Preprocessor1_Model052 |
52 | rsq | standard | 0.8419245 | 5 | 0.0168224 | Preprocessor1_Model052 |
53 | rmse | standard | 5056.2074325 | 5 | 255.6109658 | Preprocessor1_Model053 |
53 | rsq | standard | 0.8415965 | 5 | 0.0167222 | Preprocessor1_Model053 |
54 | rmse | standard | 5070.5846465 | 5 | 256.6988728 | Preprocessor1_Model054 |
54 | rsq | standard | 0.8417833 | 5 | 0.0168816 | Preprocessor1_Model054 |
55 | rmse | standard | 5092.3323325 | 5 | 254.2759554 | Preprocessor1_Model055 |
55 | rsq | standard | 0.8417379 | 5 | 0.0166438 | Preprocessor1_Model055 |
56 | rmse | standard | 5113.9907711 | 5 | 252.9157752 | Preprocessor1_Model056 |
56 | rsq | standard | 0.8418789 | 5 | 0.0168044 | Preprocessor1_Model056 |
57 | rmse | standard | 5141.3076935 | 5 | 253.8617620 | Preprocessor1_Model057 |
57 | rsq | standard | 0.8410924 | 5 | 0.0169582 | Preprocessor1_Model057 |
58 | rmse | standard | 5164.6103863 | 5 | 254.3423992 | Preprocessor1_Model058 |
58 | rsq | standard | 0.8409587 | 5 | 0.0171882 | Preprocessor1_Model058 |
59 | rmse | standard | 5187.1686629 | 5 | 255.3040686 | Preprocessor1_Model059 |
59 | rsq | standard | 0.8407301 | 5 | 0.0174055 | Preprocessor1_Model059 |
60 | rmse | standard | 5206.3796593 | 5 | 255.0693833 | Preprocessor1_Model060 |
60 | rsq | standard | 0.8407261 | 5 | 0.0173943 | Preprocessor1_Model060 |
61 | rmse | standard | 5226.1861062 | 5 | 251.1256514 | Preprocessor1_Model061 |
61 | rsq | standard | 0.8408629 | 5 | 0.0173581 | Preprocessor1_Model061 |
62 | rmse | standard | 5246.6497017 | 5 | 250.4438026 | Preprocessor1_Model062 |
62 | rsq | standard | 0.8408098 | 5 | 0.0172474 | Preprocessor1_Model062 |
63 | rmse | standard | 5260.6746375 | 5 | 253.8121305 | Preprocessor1_Model063 |
63 | rsq | standard | 0.8411199 | 5 | 0.0173674 | Preprocessor1_Model063 |
64 | rmse | standard | 5282.6943293 | 5 | 255.2813460 | Preprocessor1_Model064 |
64 | rsq | standard | 0.8409379 | 5 | 0.0173612 | Preprocessor1_Model064 |
65 | rmse | standard | 5310.0626579 | 5 | 255.2820449 | Preprocessor1_Model065 |
65 | rsq | standard | 0.8403366 | 5 | 0.0176917 | Preprocessor1_Model065 |
66 | rmse | standard | 5330.5619104 | 5 | 255.2167332 | Preprocessor1_Model066 |
66 | rsq | standard | 0.8401139 | 5 | 0.0178185 | Preprocessor1_Model066 |
67 | rmse | standard | 5345.0835302 | 5 | 254.7457515 | Preprocessor1_Model067 |
67 | rsq | standard | 0.8401665 | 5 | 0.0179959 | Preprocessor1_Model067 |
68 | rmse | standard | 5359.6814118 | 5 | 250.1877437 | Preprocessor1_Model068 |
68 | rsq | standard | 0.8405975 | 5 | 0.0177160 | Preprocessor1_Model068 |
69 | rmse | standard | 5390.7338072 | 5 | 248.9819478 | Preprocessor1_Model069 |
69 | rsq | standard | 0.8402191 | 5 | 0.0177109 | Preprocessor1_Model069 |
70 | rmse | standard | 5406.9539519 | 5 | 247.6082823 | Preprocessor1_Model070 |
70 | rsq | standard | 0.8405627 | 5 | 0.0176769 | Preprocessor1_Model070 |
71 | rmse | standard | 5424.0772491 | 5 | 249.3965793 | Preprocessor1_Model071 |
71 | rsq | standard | 0.8404457 | 5 | 0.0177430 | Preprocessor1_Model071 |
72 | rmse | standard | 5453.3601870 | 5 | 245.8968002 | Preprocessor1_Model072 |
72 | rsq | standard | 0.8397146 | 5 | 0.0176040 | Preprocessor1_Model072 |
73 | rmse | standard | 5474.6071400 | 5 | 243.8752587 | Preprocessor1_Model073 |
73 | rsq | standard | 0.8394817 | 5 | 0.0176707 | Preprocessor1_Model073 |
74 | rmse | standard | 5500.3324813 | 5 | 238.5758351 | Preprocessor1_Model074 |
74 | rsq | standard | 0.8389886 | 5 | 0.0176102 | Preprocessor1_Model074 |
75 | rmse | standard | 5533.2174181 | 5 | 235.5433973 | Preprocessor1_Model075 |
75 | rsq | standard | 0.8385261 | 5 | 0.0176846 | Preprocessor1_Model075 |
76 | rmse | standard | 5560.6640148 | 5 | 235.6440023 | Preprocessor1_Model076 |
76 | rsq | standard | 0.8381477 | 5 | 0.0179443 | Preprocessor1_Model076 |
77 | rmse | standard | 5584.7835166 | 5 | 236.0315921 | Preprocessor1_Model077 |
77 | rsq | standard | 0.8377637 | 5 | 0.0181491 | Preprocessor1_Model077 |
78 | rmse | standard | 5604.5919775 | 5 | 236.6764084 | Preprocessor1_Model078 |
78 | rsq | standard | 0.8377688 | 5 | 0.0182580 | Preprocessor1_Model078 |
79 | rmse | standard | 5633.3447960 | 5 | 235.5720509 | Preprocessor1_Model079 |
79 | rsq | standard | 0.8372383 | 5 | 0.0183068 | Preprocessor1_Model079 |
80 | rmse | standard | 5649.1854764 | 5 | 233.0229942 | Preprocessor1_Model080 |
80 | rsq | standard | 0.8376092 | 5 | 0.0183126 | Preprocessor1_Model080 |
81 | rmse | standard | 5669.7011509 | 5 | 227.1266987 | Preprocessor1_Model081 |
81 | rsq | standard | 0.8380915 | 5 | 0.0180356 | Preprocessor1_Model081 |
82 | rmse | standard | 5692.7484482 | 5 | 218.4712738 | Preprocessor1_Model082 |
82 | rsq | standard | 0.8383354 | 5 | 0.0177220 | Preprocessor1_Model082 |
83 | rmse | standard | 5715.0654889 | 5 | 217.8337115 | Preprocessor1_Model083 |
83 | rsq | standard | 0.8382711 | 5 | 0.0177845 | Preprocessor1_Model083 |
84 | rmse | standard | 5737.5170256 | 5 | 217.0799424 | Preprocessor1_Model084 |
84 | rsq | standard | 0.8381157 | 5 | 0.0177577 | Preprocessor1_Model084 |
85 | rmse | standard | 5765.2142956 | 5 | 215.0876605 | Preprocessor1_Model085 |
85 | rsq | standard | 0.8381152 | 5 | 0.0176802 | Preprocessor1_Model085 |
86 | rmse | standard | 5790.4235698 | 5 | 210.4478847 | Preprocessor1_Model086 |
86 | rsq | standard | 0.8379283 | 5 | 0.0174265 | Preprocessor1_Model086 |
87 | rmse | standard | 5820.4823630 | 5 | 207.4782758 | Preprocessor1_Model087 |
87 | rsq | standard | 0.8379163 | 5 | 0.0172870 | Preprocessor1_Model087 |
88 | rmse | standard | 5846.0684056 | 5 | 204.0553549 | Preprocessor1_Model088 |
88 | rsq | standard | 0.8375326 | 5 | 0.0172937 | Preprocessor1_Model088 |
89 | rmse | standard | 5872.1645360 | 5 | 198.7234937 | Preprocessor1_Model089 |
89 | rsq | standard | 0.8370917 | 5 | 0.0170569 | Preprocessor1_Model089 |
90 | rmse | standard | 5893.2743975 | 5 | 196.8409955 | Preprocessor1_Model090 |
90 | rsq | standard | 0.8370979 | 5 | 0.0169406 | Preprocessor1_Model090 |
91 | rmse | standard | 5921.1560325 | 5 | 195.0947100 | Preprocessor1_Model091 |
91 | rsq | standard | 0.8372950 | 5 | 0.0169937 | Preprocessor1_Model091 |
92 | rmse | standard | 5949.2206657 | 5 | 188.9127415 | Preprocessor1_Model092 |
92 | rsq | standard | 0.8372605 | 5 | 0.0168327 | Preprocessor1_Model092 |
93 | rmse | standard | 5974.8455124 | 5 | 185.1585418 | Preprocessor1_Model093 |
93 | rsq | standard | 0.8372469 | 5 | 0.0169116 | Preprocessor1_Model093 |
94 | rmse | standard | 6002.5265350 | 5 | 183.7416926 | Preprocessor1_Model094 |
94 | rsq | standard | 0.8369820 | 5 | 0.0168052 | Preprocessor1_Model094 |
95 | rmse | standard | 6033.7125940 | 5 | 182.7454155 | Preprocessor1_Model095 |
95 | rsq | standard | 0.8364537 | 5 | 0.0168431 | Preprocessor1_Model095 |
96 | rmse | standard | 6058.9580811 | 5 | 181.8945963 | Preprocessor1_Model096 |
96 | rsq | standard | 0.8368205 | 5 | 0.0168068 | Preprocessor1_Model096 |
97 | rmse | standard | 6089.5336705 | 5 | 181.2305913 | Preprocessor1_Model097 |
97 | rsq | standard | 0.8366151 | 5 | 0.0167848 | Preprocessor1_Model097 |
98 | rmse | standard | 6118.9537048 | 5 | 178.0149517 | Preprocessor1_Model098 |
98 | rsq | standard | 0.8367219 | 5 | 0.0166843 | Preprocessor1_Model098 |
99 | rmse | standard | 6148.5943339 | 5 | 176.6887710 | Preprocessor1_Model099 |
99 | rsq | standard | 0.8367294 | 5 | 0.0166779 | Preprocessor1_Model099 |
100 | rmse | standard | 6181.5358696 | 5 | 171.9513420 | Preprocessor1_Model100 |
100 | rsq | standard | 0.8367538 | 5 | 0.0165249 | Preprocessor1_Model100 |
Comparing the RMSE across the different values of k, we can see that taking too many or too few neighbours yields poor accuracy: too few neighbours overfit the training data, while too many oversmooth the predictions.
ins_results %>%
filter(.metric == "rmse") %>%
ggplot(aes(neighbors, mean)) +
geom_line(col="red") +
geom_point(col = "#336980")+
labs(title = "Cross-validated RMSE vs number of neighbours")
The optimal value turns out to be k = 21 neighbours, which can be interpreted as follows: to predict a policyholder's charges, it is best to look at the 21 nearest observations in terms of the variables age + bmi + children + health + region.
ins_min <- ins_results %>%
filter(.metric == "rmse") %>%
filter(mean == min(mean))
ins_min %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
neighbors | .metric | .estimator | mean | n | std_err | .config |
---|---|---|---|---|---|---|
21 | rmse | standard | 4629.166 | 5 | 260.0895 | Preprocessor1_Model021 |
kmin <- ins_min %>% pull(neighbors)
We now retrain our model with the optimal value k = 21 and evaluate its performance.
set.seed(123456)
ins_spec <-
nearest_neighbor(weight_func = "rectangular", neighbors = kmin) %>%
set_engine("kknn") %>%
set_mode("regression")
ins_fit <- workflow() %>%
add_recipe(ins_recipe) %>%
add_model(ins_spec) %>%
fit(data = insurance_train)
ins_fit
## == Workflow [trained] ==========================================================
## Preprocessor: Recipe
## Model: nearest_neighbor()
##
## -- Preprocessor ----------------------------------------------------------------
## 2 Recipe Steps
##
## * step_scale()
## * step_center()
##
## -- Model -----------------------------------------------------------------------
##
## Call:
## kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(21L, data, 5), kernel = ~"rectangular")
##
## Type of response variable: continuous
## minimal mean absolute error: 2614.916
## Minimal mean squared error: 21519995
## Best kernel: rectangular
## Best k: 21
ins_res <- ins_fit %>%
predict(insurance_test) %>%
bind_cols(insurance_test)
ins_res %>% ggplot(aes(y = charges, x = .pred, col = health)) +
geom_point() +
scale_color_manual(values = c("#79a346", "#245373","#964a74")) +
geom_abline(slope = 1)
ins_summary <- ins_res %>%
metrics(truth = charges, estimate = .pred)
ins_summary %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
.metric | .estimator | .estimate |
---|---|---|
rmse | standard | 4642.8788742 |
rsq | standard | 0.8698845 |
mae | standard | 2675.5606994 |
We start by training a random forest model on the training set, using all the variables with their default parameters, and then evaluate it on the test set.
library(randomForest)
model_rand_forest <- rand_forest(mode = "regression") %>%
set_engine("randomForest") %>%
fit(charges ~ age + bmi + children + health +
region + sex, data = insurance_train)
model_rand_forest_pred <- model_rand_forest %>%
predict(new_data = insurance_test) %>%
bind_cols(insurance_test %>% dplyr::select(charges))
model_rand_forest_pred %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left") %>%
# kableExtra::kable_paper() %>%
kableExtra::scroll_box(width = "230px", height = "300px")
.pred | charges |
---|---|
10451.433 | 6406.411 |
3660.008 | 2198.190 |
6335.282 | 4687.797 |
2650.023 | 1625.434 |
5255.351 | 3046.062 |
5914.029 | 4949.759 |
7127.264 | 6313.759 |
8745.097 | 6079.672 |
24201.908 | 23568.272 |
37416.161 | 37742.576 |
45024.699 | 47496.494 |
8077.114 | 5989.524 |
3122.390 | 1743.214 |
7023.986 | 5920.104 |
19461.594 | 16577.780 |
3258.499 | 1532.470 |
22657.791 | 21098.554 |
12131.086 | 8026.667 |
17818.219 | 15820.699 |
6291.896 | 5003.853 |
6277.419 | 4646.759 |
12104.121 | 11488.317 |
2753.991 | 1705.624 |
4696.097 | 3385.399 |
16312.455 | 32734.186 |
8002.223 | 6082.405 |
13703.279 | 12815.445 |
3901.848 | 2457.211 |
3712.193 | 1842.519 |
21616.845 | 19964.746 |
10121.862 | 6948.701 |
6680.250 | 5152.134 |
12675.568 | 10407.086 |
11412.459 | 8116.680 |
6258.789 | 4005.423 |
9949.955 | 7419.478 |
39358.065 | 43753.337 |
5224.614 | 4883.866 |
3559.738 | 1639.563 |
3799.573 | 2130.676 |
36462.321 | 37133.898 |
11101.341 | 7147.105 |
5816.038 | 1980.070 |
9060.899 | 8520.026 |
8562.868 | 7371.772 |
5973.723 | 2483.736 |
5792.565 | 5253.524 |
23597.827 | 19515.542 |
5301.054 | 2689.495 |
11767.640 | 24227.337 |
7042.353 | 6710.192 |
18616.737 | 19444.266 |
9463.841 | 7152.671 |
2796.933 | 1832.094 |
41704.425 | 41097.162 |
14748.181 | 13047.332 |
33955.399 | 33750.292 |
13813.137 | 20462.998 |
42807.802 | 46151.124 |
14392.991 | 14590.632 |
11574.902 | 9282.481 |
11499.566 | 9617.662 |
13989.587 | 12928.791 |
44759.885 | 48549.178 |
6240.538 | 4237.127 |
11742.170 | 9625.920 |
13794.900 | 9432.925 |
5927.263 | 3172.018 |
38474.315 | 38746.355 |
10729.208 | 9249.495 |
5917.838 | 20177.671 |
6158.480 | 4151.029 |
11956.411 | 8444.474 |
10554.241 | 8835.265 |
11062.991 | 7421.195 |
6078.886 | 4894.753 |
43443.660 | 47928.030 |
16418.889 | 13937.666 |
16314.367 | 13217.094 |
14569.062 | 13981.850 |
5953.594 | 3554.203 |
2906.175 | 14133.038 |
11649.409 | 10043.249 |
7068.275 | 3180.510 |
8246.913 | 3481.868 |
15812.307 | 16455.708 |
42067.931 | 42303.692 |
6874.941 | 5846.918 |
11622.378 | 8302.536 |
4344.020 | 1261.859 |
9964.540 | 9264.797 |
21168.982 | 19594.810 |
5642.873 | 2727.395 |
10590.816 | 8968.330 |
10943.152 | 9788.866 |
6937.360 | 18804.752 |
7527.599 | 5969.723 |
4516.087 | 2254.797 |
6912.075 | 5926.846 |
36192.333 | 37079.372 |
3844.132 | 1149.396 |
16017.208 | 12731.000 |
13062.726 | 11454.022 |
4284.787 | 2497.038 |
11150.924 | 9563.029 |
5604.969 | 4347.023 |
12271.593 | 12475.351 |
42997.411 | 48885.136 |
11686.132 | 11455.280 |
3781.670 | 27724.289 |
10989.137 | 8413.463 |
12825.503 | 9861.025 |
6787.646 | 5972.378 |
7088.685 | 6196.448 |
14239.052 | 13887.204 |
14338.218 | 10231.500 |
6978.112 | 3268.847 |
45184.647 | 45863.205 |
6187.720 | 3935.180 |
3648.610 | 1646.430 |
10118.581 | 9058.730 |
5073.998 | 2128.431 |
8209.503 | 7256.723 |
3308.223 | 1665.000 |
6452.114 | 3861.210 |
34883.391 | 43943.876 |
14354.402 | 13635.638 |
8397.619 | 7640.309 |
8291.151 | 5594.846 |
2957.054 | 1633.044 |
12082.887 | 10713.644 |
11204.735 | 9182.170 |
6266.344 | 11326.715 |
10128.842 | 9391.346 |
14202.350 | 14410.932 |
3844.630 | 2709.112 |
20321.587 | 20149.323 |
22536.564 | 32787.459 |
4402.245 | 1712.227 |
12567.328 | 12430.953 |
22366.704 | 24667.419 |
5513.127 | 3410.324 |
13856.903 | 12363.547 |
11511.613 | 10156.783 |
12414.191 | 29186.482 |
39494.548 | 40273.645 |
12416.558 | 10976.246 |
11902.851 | 9541.696 |
6760.476 | 5375.038 |
8204.819 | 6113.231 |
5890.396 | 5469.007 |
43028.597 | 43254.418 |
19710.618 | 19539.243 |
4854.934 | 2416.955 |
13832.884 | 14319.031 |
7137.109 | 6933.242 |
12375.762 | 9869.810 |
23249.079 | 24520.264 |
11897.721 | 11848.141 |
7031.613 | 28476.735 |
9815.333 | 9414.920 |
5687.986 | 2842.761 |
36363.311 | 55135.402 |
8414.867 | 7445.918 |
3901.651 | 1621.883 |
6525.474 | 4719.737 |
5189.801 | 1526.312 |
13131.374 | 12323.936 |
35773.630 | 36021.011 |
4439.393 | 2974.126 |
12839.126 | 10601.632 |
3055.762 | 1141.445 |
10602.393 | 8457.818 |
5161.794 | 3392.365 |
12671.686 | 26140.360 |
4676.341 | 2789.057 |
8156.097 | 4877.981 |
19094.409 | 19719.695 |
6167.933 | 5272.176 |
13393.098 | 13063.883 |
6142.009 | 4661.286 |
13871.941 | 14382.709 |
5736.741 | 2473.334 |
6369.748 | 5245.227 |
13169.838 | 13462.520 |
12201.263 | 25333.333 |
6609.788 | 2913.569 |
7975.012 | 6289.755 |
12478.543 | 10096.970 |
9088.949 | 7348.142 |
14399.332 | 9487.644 |
38840.379 | 39047.285 |
37484.189 | 38998.546 |
4215.930 | 12609.887 |
10440.395 | 9500.573 |
18650.951 | 19199.944 |
7654.150 | 4915.060 |
4152.923 | 3378.910 |
6302.337 | 5484.467 |
16418.292 | 16420.495 |
8125.900 | 7986.475 |
10332.652 | 8627.541 |
7084.050 | 4438.263 |
6529.118 | 3987.926 |
14316.166 | 12495.291 |
3889.998 | 1711.027 |
3338.757 | 2020.177 |
44657.913 | 44423.803 |
15168.442 | 13747.872 |
6549.855 | 22493.660 |
40960.798 | 39727.614 |
9363.662 | 17929.303 |
4585.558 | 2480.979 |
46105.477 | 48970.248 |
10449.176 | 8978.185 |
13405.882 | 13204.286 |
15911.262 | 15161.534 |
12464.283 | 11353.228 |
12260.997 | 9748.911 |
5869.858 | 20420.605 |
11022.017 | 8988.159 |
13679.689 | 10493.946 |
39453.551 | 38282.749 |
34293.996 | 34166.273 |
16221.916 | 14254.608 |
11720.541 | 9991.038 |
12215.450 | 11085.587 |
9545.484 | 7623.518 |
5902.753 | 2203.736 |
40335.488 | 40941.285 |
5958.347 | 3989.841 |
9031.146 | 7727.253 |
12087.194 | 10982.501 |
5837.904 | 2899.489 |
19191.890 | 19350.369 |
10388.267 | 7650.774 |
5535.421 | 2850.684 |
6377.622 | 2632.992 |
11864.433 | 9447.382 |
13020.834 | 36910.608 |
37660.758 | 38415.474 |
19573.251 | 20296.863 |
4350.166 | 12890.058 |
39993.455 | 41661.602 |
7401.246 | 7537.164 |
8968.118 | 7162.012 |
7544.543 | 6985.507 |
46090.708 | 47269.854 |
2862.347 | 1633.962 |
38734.561 | 37607.528 |
15476.619 | 30063.581 |
8276.101 | 7337.748 |
34502.953 | 34254.053 |
4922.545 | 3292.530 |
11391.260 | 10959.330 |
22680.237 | 24535.699 |
44181.887 | 47403.880 |
12158.897 | 8534.672 |
12102.118 | 11931.125 |
43595.950 | 46718.163 |
7686.517 | 2464.619 |
7341.964 | 7201.701 |
8920.595 | 22395.744 |
12315.288 | 10325.206 |
modelPerform_rand_forest = data.frame(
RMSE = RMSE(
model_rand_forest_pred$.pred,
model_rand_forest_pred$charges
),
R2 = R2(
model_rand_forest_pred$.pred,
model_rand_forest_pred$charges
)
)
modelPerform_rand_forest %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
RMSE | R2 |
---|---|
4720.893 | 0.8689894 |
This model explains 86.90% of the total variance, with an RMSE of 4720.893.
We now try to improve this performance by tuning the hyperparameters min_n (the minimum number of observations that a node must contain to be split) and mtry (the number of predictors randomly sampled as candidates at each split when building the trees).
To automate preprocessing, we use a recipe to center and scale the numeric variables age and bmi, and combine it with the model specification in a workflow.
#creating recipe
ins_recipe2 <-
recipe(charges ~ age + bmi + children + health + region + sex, data = insurance_train) %>%
step_scale(age, bmi) %>%
step_center(age, bmi)
#model specification
ins_spec_rf <- rand_forest(min_n = tune(),mtry=tune()) %>%
set_engine("randomForest") %>%
set_mode("regression")
#workflow(model+recipe)
ins_wf_rf <- workflow() %>%
add_recipe(ins_recipe2) %>%
add_model(ins_spec_rf)
ins_wf_rf
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: rand_forest()
##
## -- Preprocessor ----------------------------------------------------------------
## 2 Recipe Steps
##
## * step_scale()
## * step_center()
##
## -- Model -----------------------------------------------------------------------
## Random Forest Model Specification (regression)
##
## Main Arguments:
## mtry = tune()
## min_n = tune()
##
## Computational engine: randomForest
Hyperparameter tuning:
We will use grid search to find the optimal combination of the mtry and min_n hyperparameters.
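The `data_grid` object used below is defined earlier in the report. As a hypothetical reconstruction, a regular grid consistent with the 28 combinations that appear in the results table (mtry in 3–6, min_n in 26, 28, …, 38) could be built like this:

```r
# Hypothetical reconstruction of the tuning grid (the actual `data_grid`
# is created earlier in the report). expand.grid enumerates every
# combination, varying mtry fastest, which matches the order of
# Preprocessor1_Model01..28 in the results table below.
data_grid <- expand.grid(mtry = 3:6, min_n = seq(26, 38, by = 2))
nrow(data_grid)  # 28 candidate combinations
```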
set.seed(123456)
ins_results_rf_grid <- ins_wf_rf %>%
  tune_grid(resamples = ins_vfold, grid = data_grid)
ins_results_rf <- ins_results_rf_grid %>%
  collect_metrics()
ins_results_rf %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = T, position = "left") %>%
kableExtra::kable_paper() %>%
kableExtra::scroll_box( height = "300px")
mtry | min_n | .metric | .estimator | mean | n | std_err | .config |
---|---|---|---|---|---|---|---|
3 | 26 | rmse | standard | 4502.9713269 | 5 | 246.9508460 | Preprocessor1_Model01 |
3 | 26 | rsq | standard | 0.8572369 | 5 | 0.0178017 | Preprocessor1_Model01 |
4 | 26 | rmse | standard | 4486.3605136 | 5 | 247.3486272 | Preprocessor1_Model02 |
4 | 26 | rsq | standard | 0.8575817 | 5 | 0.0176938 | Preprocessor1_Model02 |
5 | 26 | rmse | standard | 4507.6436824 | 5 | 241.7663073 | Preprocessor1_Model03 |
5 | 26 | rsq | standard | 0.8565050 | 5 | 0.0174662 | Preprocessor1_Model03 |
6 | 26 | rmse | standard | 4540.8392467 | 5 | 238.8689672 | Preprocessor1_Model04 |
6 | 26 | rsq | standard | 0.8547268 | 5 | 0.0173503 | Preprocessor1_Model04 |
3 | 28 | rmse | standard | 4492.4823224 | 5 | 246.0261177 | Preprocessor1_Model05 |
3 | 28 | rsq | standard | 0.8579409 | 5 | 0.0174977 | Preprocessor1_Model05 |
4 | 28 | rmse | standard | 4480.4196667 | 5 | 248.8713755 | Preprocessor1_Model06 |
4 | 28 | rsq | standard | 0.8579075 | 5 | 0.0178142 | Preprocessor1_Model06 |
5 | 28 | rmse | standard | 4506.2273528 | 5 | 242.0233731 | Preprocessor1_Model07 |
5 | 28 | rsq | standard | 0.8565265 | 5 | 0.0174745 | Preprocessor1_Model07 |
6 | 28 | rmse | standard | 4531.7370785 | 5 | 240.9903022 | Preprocessor1_Model08 |
6 | 28 | rsq | standard | 0.8551576 | 5 | 0.0174482 | Preprocessor1_Model08 |
3 | 30 | rmse | standard | 4502.4015236 | 5 | 253.0207005 | Preprocessor1_Model09 |
3 | 30 | rsq | standard | 0.8572583 | 5 | 0.0180438 | Preprocessor1_Model09 |
4 | 30 | rmse | standard | 4480.9065073 | 5 | 252.5024903 | Preprocessor1_Model10 |
4 | 30 | rsq | standard | 0.8577283 | 5 | 0.0179642 | Preprocessor1_Model10 |
5 | 30 | rmse | standard | 4504.6791618 | 5 | 247.8567109 | Preprocessor1_Model11 |
5 | 30 | rsq | standard | 0.8565087 | 5 | 0.0178983 | Preprocessor1_Model11 |
6 | 30 | rmse | standard | 4531.1114752 | 5 | 241.6965059 | Preprocessor1_Model12 |
6 | 30 | rsq | standard | 0.8551242 | 5 | 0.0174564 | Preprocessor1_Model12 |
3 | 32 | rmse | standard | 4499.8684604 | 5 | 249.1454239 | Preprocessor1_Model13 |
3 | 32 | rsq | standard | 0.8574502 | 5 | 0.0178746 | Preprocessor1_Model13 |
4 | 32 | rmse | standard | 4485.3582706 | 5 | 254.1382742 | Preprocessor1_Model14 |
4 | 32 | rsq | standard | 0.8573791 | 5 | 0.0181175 | Preprocessor1_Model14 |
5 | 32 | rmse | standard | 4499.0556140 | 5 | 249.3251531 | Preprocessor1_Model15 |
5 | 32 | rsq | standard | 0.8567260 | 5 | 0.0178276 | Preprocessor1_Model15 |
6 | 32 | rmse | standard | 4517.2728936 | 5 | 246.0942965 | Preprocessor1_Model16 |
6 | 32 | rsq | standard | 0.8558118 | 5 | 0.0176314 | Preprocessor1_Model16 |
3 | 34 | rmse | standard | 4494.1781361 | 5 | 247.5034820 | Preprocessor1_Model17 |
3 | 34 | rsq | standard | 0.8579955 | 5 | 0.0177320 | Preprocessor1_Model17 |
4 | 34 | rmse | standard | 4477.2679787 | 5 | 252.1122315 | Preprocessor1_Model18 |
4 | 34 | rsq | standard | 0.8578414 | 5 | 0.0179604 | Preprocessor1_Model18 |
5 | 34 | rmse | standard | 4498.0367395 | 5 | 255.1573410 | Preprocessor1_Model19 |
5 | 34 | rsq | standard | 0.8565994 | 5 | 0.0183137 | Preprocessor1_Model19 |
6 | 34 | rmse | standard | 4519.5092097 | 5 | 249.7806274 | Preprocessor1_Model20 |
6 | 34 | rsq | standard | 0.8555612 | 5 | 0.0179683 | Preprocessor1_Model20 |
3 | 36 | rmse | standard | 4498.9937117 | 5 | 253.1112587 | Preprocessor1_Model21 |
3 | 36 | rsq | standard | 0.8575007 | 5 | 0.0180595 | Preprocessor1_Model21 |
4 | 36 | rmse | standard | 4482.6950210 | 5 | 257.3142287 | Preprocessor1_Model22 |
4 | 36 | rsq | standard | 0.8575048 | 5 | 0.0183140 | Preprocessor1_Model22 |
5 | 36 | rmse | standard | 4493.8055794 | 5 | 253.7648297 | Preprocessor1_Model23 |
5 | 36 | rsq | standard | 0.8568971 | 5 | 0.0181862 | Preprocessor1_Model23 |
6 | 36 | rmse | standard | 4519.1213683 | 5 | 250.4609414 | Preprocessor1_Model24 |
6 | 36 | rsq | standard | 0.8555345 | 5 | 0.0180142 | Preprocessor1_Model24 |
3 | 38 | rmse | standard | 4518.8059165 | 5 | 249.9243199 | Preprocessor1_Model25 |
3 | 38 | rsq | standard | 0.8566750 | 5 | 0.0179041 | Preprocessor1_Model25 |
4 | 38 | rmse | standard | 4486.1647083 | 5 | 257.0999480 | Preprocessor1_Model26 |
4 | 38 | rsq | standard | 0.8571813 | 5 | 0.0182885 | Preprocessor1_Model26 |
5 | 38 | rmse | standard | 4498.2283354 | 5 | 252.1359709 | Preprocessor1_Model27 |
5 | 38 | rsq | standard | 0.8564944 | 5 | 0.0180877 | Preprocessor1_Model27 |
6 | 38 | rmse | standard | 4517.3020586 | 5 | 253.0717244 | Preprocessor1_Model28 |
6 | 38 | rsq | standard | 0.8554945 | 5 | 0.0182853 | Preprocessor1_Model28 |
Visualizing the error for each combination of mtry and min_n:
ins_results_rf %>%
  filter(.metric == "rmse") %>%
  mutate(min_n = factor(min_n)) %>%
  ggplot(aes(x = mtry, y = mean, col = min_n)) +
  geom_line(alpha = 0.5, size = 1.5) +
  geom_point() +
  labs(title = "RMSE for each combination of mtry & min_n", y = "rmse") +
  theme(plot.title = element_text(hjust = 0.5))
Finally, we select the best hyperparameters for our model, then update the original model specification to create the final specification.
best_param <- ins_results_rf %>%
filter(.metric == "rmse") %>%
filter(mean == min(mean))
best_param %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
mtry | min_n | .metric | .estimator | mean | n | std_err | .config |
---|---|---|---|---|---|---|---|
4 | 34 | rmse | standard | 4477.268 | 5 | 252.1122 | Preprocessor1_Model18 |
best_mtry <- best_param %>% pull(mtry)
best_min_n <- best_param %>% pull(min_n)
best_rmse <- select_best(ins_results_rf_grid, "rmse")
final_rf <- finalize_model(
ins_spec_rf,
best_rmse
)
final_rf
## Random Forest Model Specification (regression)
##
## Main Arguments:
## mtry = 4
## min_n = 34
##
## Computational engine: randomForest
The optimal hyperparameters are mtry = 4 and min_n = 34.
last_fit(final_rf, charges~age + bmi + children + health + region + sex, insurance_data_split) %>%
collect_metrics() %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
.metric | .estimator | .estimate | .config |
---|---|---|---|
rmse | standard | 4596.2229943 | Preprocessor1_Model1 |
rsq | standard | 0.8724205 | Preprocessor1_Model1 |
Hyperparameter tuning improved performance: on the test set, the RMSE drops from 4720.89 to 4596.22 and the R² rises from 0.869 to 0.872.
model_final_rf <- rand_forest(mode = "regression", mtry = best_mtry, min_n = best_min_n) %>%
  set_engine("randomForest") %>%
  fit(charges ~ age + bmi + children + health + region + sex, data = insurance_train)
model_final_rf_pred <- model_final_rf %>%
predict(new_data = insurance_test) %>%
bind_cols(insurance_test %>% dplyr::select(charges))
model_final_rf_pred %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left") %>%
kableExtra::scroll_box(width = "230px", height = "300px")
.pred | charges |
---|---|
9313.516 | 6406.411 |
2825.881 | 2198.190 |
6488.111 | 4687.797 |
2035.577 | 1625.434 |
5258.149 | 3046.062 |
5496.489 | 4949.759 |
7100.562 | 6313.759 |
8053.543 | 6079.672 |
25660.807 | 23568.272 |
38533.203 | 37742.576 |
46657.647 | 47496.494 |
7361.800 | 5989.524 |
2720.721 | 1743.214 |
7055.363 | 5920.104 |
18660.887 | 16577.780 |
2589.512 | 1532.470 |
23737.523 | 21098.554 |
11330.600 | 8026.667 |
17857.613 | 15820.699 |
5998.075 | 5003.853 |
6289.880 | 4646.759 |
13072.438 | 11488.317 |
2343.480 | 1705.624 |
4775.943 | 3385.399 |
16957.235 | 32734.186 |
7062.731 | 6082.405 |
13898.893 | 12815.445 |
3267.267 | 2457.211 |
3034.871 | 1842.519 |
23232.219 | 19964.746 |
9521.560 | 6948.701 |
6228.107 | 5152.134 |
12215.901 | 10407.086 |
10353.571 | 8116.680 |
6263.941 | 4005.423 |
9750.669 | 7419.478 |
41995.343 | 43753.337 |
5326.750 | 4883.866 |
2948.682 | 1639.563 |
3002.279 | 2130.676 |
37745.972 | 37133.898 |
10158.129 | 7147.105 |
3892.773 | 1980.070 |
8940.956 | 8520.026 |
7633.471 | 7371.772 |
4898.152 | 2483.736 |
5305.123 | 5253.524 |
23834.665 | 19515.542 |
4761.666 | 2689.495 |
12532.750 | 24227.337 |
7516.443 | 6710.192 |
18372.217 | 19444.266 |
8560.121 | 7152.671 |
2633.732 | 1832.094 |
43264.498 | 41097.162 |
14419.087 | 13047.332 |
35894.156 | 33750.292 |
13747.911 | 20462.998 |
44097.105 | 46151.124 |
14596.187 | 14590.632 |
11087.346 | 9282.481 |
12631.051 | 9617.662 |
14160.961 | 12928.791 |
46682.437 | 48549.178 |
6449.327 | 4237.127 |
12739.273 | 9625.920 |
12256.187 | 9432.925 |
5383.750 | 3172.018 |
39760.661 | 38746.355 |
11286.481 | 9249.495 |
5172.913 | 20177.671 |
6106.909 | 4151.029 |
11648.594 | 8444.474 |
10782.399 | 8835.265 |
10904.313 | 7421.195 |
5835.489 | 4894.753 |
46091.216 | 47928.030 |
17849.756 | 13937.666 |
15447.220 | 13217.094 |
15185.567 | 13981.850 |
4889.230 | 3554.203 |
2385.165 | 14133.038 |
11791.167 | 10043.249 |
5976.146 | 3180.510 |
7568.008 | 3481.868 |
16560.577 | 16455.708 |
45292.592 | 42303.692 |
6790.604 | 5846.918 |
10633.440 | 8302.536 |
2933.802 | 1261.859 |
10725.511 | 9264.797 |
22349.628 | 19594.810 |
4968.008 | 2727.395 |
10294.723 | 8968.330 |
11107.695 | 9788.866 |
5842.575 | 18804.752 |
6509.857 | 5969.723 |
3816.306 | 2254.797 |
6324.843 | 5926.846 |
37967.233 | 37079.372 |
2151.573 | 1149.396 |
14487.568 | 12731.000 |
12908.331 | 11454.022 |
3870.831 | 2497.038 |
10479.305 | 9563.029 |
5726.918 | 4347.023 |
13383.609 | 12475.351 |
43341.317 | 48885.136 |
11880.460 | 11455.280 |
2534.345 | 27724.289 |
9994.734 | 8413.463 |
11822.284 | 9861.025 |
6356.439 | 5972.378 |
6746.965 | 6196.448 |
14780.395 | 13887.204 |
12753.393 | 10231.500 |
6456.774 | 3268.847 |
45568.914 | 45863.205 |
5887.979 | 3935.180 |
2330.390 | 1646.430 |
10068.375 | 9058.730 |
3264.518 | 2128.431 |
7810.201 | 7256.723 |
2340.970 | 1665.000 |
6346.507 | 3861.210 |
39966.216 | 43943.876 |
14493.389 | 13635.638 |
7954.188 | 7640.309 |
8104.312 | 5594.846 |
2329.208 | 1633.044 |
12894.770 | 10713.644 |
11145.111 | 9182.170 |
5935.874 | 11326.715 |
10164.694 | 9391.346 |
14931.465 | 14410.932 |
3039.612 | 2709.112 |
20035.035 | 20149.323 |
24892.767 | 32787.459 |
3092.736 | 1712.227 |
12485.291 | 12430.953 |
23750.248 | 24667.419 |
4854.198 | 3410.324 |
13271.241 | 12363.547 |
11348.461 | 10156.783 |
12366.547 | 29186.482 |
40995.523 | 40273.645 |
12066.790 | 10976.246 |
12169.318 | 9541.696 |
6388.622 | 5375.038 |
7449.036 | 6113.231 |
5754.314 | 5469.007 |
45830.372 | 43254.418 |
19481.902 | 19539.243 |
3910.800 | 2416.955 |
14284.351 | 14319.031 |
7834.577 | 6933.242 |
12127.417 | 9869.810 |
25600.279 | 24520.264 |
12104.214 | 11848.141 |
6936.738 | 28476.735 |
9464.778 | 9414.920 |
5453.816 | 2842.761 |
39116.387 | 55135.402 |
8417.074 | 7445.918 |
3605.787 | 1621.883 |
6036.647 | 4719.737 |
4322.806 | 1526.312 |
13924.063 | 12323.936 |
36866.129 | 36021.011 |
4381.692 | 2974.126 |
12360.591 | 10601.632 |
1920.126 | 1141.445 |
10344.478 | 8457.818 |
4675.338 | 3392.365 |
11313.690 | 26140.360 |
4349.087 | 2789.057 |
8661.721 | 4877.981 |
20231.658 | 19719.695 |
5488.869 | 5272.176 |
14164.942 | 13063.883 |
5870.150 | 4661.286 |
14895.313 | 14382.709 |
5337.819 | 2473.334 |
6009.825 | 5245.227 |
13913.109 | 13462.520 |
11816.011 | 25333.333 |
6098.964 | 2913.569 |
7182.002 | 6289.755 |
11255.735 | 10096.970 |
9151.591 | 7348.142 |
13473.497 | 9487.644 |
39884.879 | 39047.285 |
45186.663 | 38998.546 |
4094.535 | 12609.887 |
10409.994 | 9500.573 |
18814.759 | 19199.944 |
7492.538 | 4915.060 |
4670.178 | 3378.910 |
6202.462 | 5484.467 |
16693.063 | 16420.495 |
7998.246 | 7986.475 |
10465.608 | 8627.541 |
7905.082 | 4438.263 |
5186.196 | 3987.926 |
13589.414 | 12495.291 |
2744.719 | 1711.027 |
3294.056 | 2020.177 |
46280.671 | 44423.803 |
16365.670 | 13747.872 |
7055.858 | 22493.660 |
42875.918 | 39727.614 |
8743.255 | 17929.303 |
4749.798 | 2480.979 |
46856.250 | 48970.248 |
10390.913 | 8978.185 |
14084.319 | 13204.286 |
16488.154 | 15161.534 |
12500.187 | 11353.228 |
13621.118 | 9748.911 |
6285.489 | 20420.605 |
10942.061 | 8988.159 |
13411.842 | 10493.946 |
40974.398 | 38282.749 |
35832.055 | 34166.273 |
15699.971 | 14254.608 |
12546.335 | 9991.038 |
12556.381 | 11085.587 |
9282.351 | 7623.518 |
4376.375 | 2203.736 |
41636.720 | 40941.285 |
5638.113 | 3989.841 |
9254.732 | 7727.253 |
12632.213 | 10982.501 |
5414.342 | 2899.489 |
20258.507 | 19350.369 |
8847.998 | 7650.774 |
5669.170 | 2850.684 |
5153.549 | 2632.992 |
12303.216 | 9447.382 |
13712.104 | 36910.608 |
39138.877 | 38415.474 |
19613.455 | 20296.863 |
3115.142 | 12890.058 |
42625.993 | 41661.602 |
6959.332 | 7537.164 |
8082.499 | 7162.012 |
7950.227 | 6985.507 |
46357.351 | 47269.854 |
2329.012 | 1633.962 |
39895.039 | 37607.528 |
14673.729 | 30063.581 |
8582.364 | 7337.748 |
35842.905 | 34254.053 |
4575.516 | 3292.530 |
12170.936 | 10959.330 |
25046.491 | 24535.699 |
46576.316 | 47403.880 |
12017.221 | 8534.672 |
13006.688 | 11931.125 |
45957.826 | 46718.163 |
6179.391 | 2464.619 |
7514.225 | 7201.701 |
8435.519 | 22395.744 |
13066.856 | 10325.206 |
modelPerform_model_final = data.frame(
RMSE = RMSE(
model_final_rf_pred$.pred,
model_final_rf_pred$charges
),
R2 = R2(
model_final_rf_pred$.pred,
model_final_rf_pred$charges
)
)
modelPerform_model_final %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
RMSE | R2 |
---|---|
4607.556 | 0.8717825 |
model_final_rf_pred %>% ggplot(aes(y = charges, x = .pred)) +
  geom_point(col = "#336980", size = 2, alpha = 0.5) +
  geom_abline(slope = 1, col = "#821f2e")
result_linear_mod = predict(mod.extend, insurance_data)
lm_rmse = RMSE(result_linear_mod, insurance_data$charges)
lm_rsq = perf %>% pull(r.squared)
knn_rmse = ins_summary %>% filter(.metric == "rmse") %>% pull(.estimate)
knn_rsq = ins_summary %>% filter(.metric == "rsq") %>% pull(.estimate)
rf_rsq = modelPerform_model_final %>% pull(R2)
rf_rmse = modelPerform_model_final %>% pull(RMSE)
model <- c("Linear model", "KNN", "Random forest")
R_squared <- c(lm_rsq, knn_rsq, rf_rsq)
RMSE <- c(lm_rmse, knn_rmse, rf_rmse)
data.frame(model, R_squared, RMSE) %>%
knitr::kable(align = "c") %>%
kableExtra::kable_styling(full_width = F, position = "left")
model | R_squared | RMSE |
---|---|---|
Linear model | 0.8721790 | 4327.960 |
KNN | 0.8698845 | 4642.879 |
Random forest | 0.8717825 | 4607.556 |
All three models achieve roughly the same accuracy, but the linear model has the major advantage of interpretability. Note that the linear model's metrics above were computed on the full dataset rather than the held-out test set, so the comparison slightly favors it.
To obtain even more accurate predictions, we suggest that this insurer collect more data on its clients, in order to explain the behavior of certain individuals who, as observed during the visualization phase, do not follow the same trend as the others.
If, even after such a study, these observations cannot be distinguished from the rest of the population, one should consider models designed for this problem, chiefly robust linear models (RLM), which assign low weights to influential points.
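To illustrate that last suggestion, here is a minimal sketch (on synthetic toy data, not the insurance dataset) of robust regression with MASS::rlm: the Huber M-estimator automatically down-weights observations with large residuals, so outliers pull the fitted line far less than under ordinary least squares.

```r
library(MASS)  # rlm: robust linear model fitted by M-estimation

set.seed(42)
# Toy data: a linear trend plus three grossly outlying points,
# mimicking the atypical individuals seen in the visualization phase.
x <- runif(100, 18, 64)
y <- 250 * x + rnorm(100, sd = 2000)
y[1:3] <- y[1:3] + 40000

fit <- rlm(y ~ x, psi = psi.huber)
# Influential observations receive weights well below 1;
# well-behaved points keep a weight of (or near) 1.
round(sort(fit$w)[1:3], 3)
```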