Introduction

Pour réaliser leurs profit, les compagnies d’assurance doivent percevoir une prime plus élevée que le montant versé à l’assuré. Pour cette raison, les compagnies d’assurance investissent beaucoup de temps, d’efforts et d’argent dans la création de modèles qui prédit avec précision les coûts des soins de santé.
Afin de remplir cette mission, nous allons dans un premier lieu analyser les facteurs qui influencent les charges médicaux et dans un second lieu essayer de construire un modèle adéquat et optimiser ses performances.

Exploration des données

Importation des données

insurance_data <- read.csv("insurance.csv")

Importation des outils nécessaires

library(ggthemes)
library(tidyverse)
library(ggridges)
library(magrittr)
library(gridExtra)
library(patchwork)
library(inspectdf)
library(corrplot)
library(tidymodels)
library(car)
library(caret)
theme_set(theme_bw())

Vue générale des données

glimpse(insurance_data)
## Rows: 1,338
## Columns: 7
## $ age      <int> 19, 18, 28, 33, 32, 31, 46, 37, 37, 60, 25, 62, 23, 56, 27...
## $ sex      <chr> "female", "male", "male", "male", "male", "female", "femal...
## $ bmi      <dbl> 27.900, 33.770, 33.000, 22.705, 28.880, 25.740, 33.440, 27...
## $ children <int> 0, 1, 3, 0, 0, 0, 1, 3, 2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0...
## $ smoker   <chr> "yes", "no", "no", "no", "no", "no", "no", "no", "no", "no...
## $ region   <chr> "southwest", "southeast", "southeast", "northwest", "north...
## $ charges  <dbl> 16884.924, 1725.552, 4449.462, 21984.471, 3866.855, 3756.6...

De nombreux facteurs qui influencent le montant qu’un assuré paye pour l’assurance médicale ne sont pas sous son contrôle. Néanmoins, il est bon d’avoir une compréhension de ce qu’ils sont. Voici quelques facteurs collectées par l’assurance, qu’on va étudier leurs influence sur le coût des primes d’assurance médicale :
On dispose d’un dataset qui comporte 1338 observations sur 7 variables

  • Age : l’âge des assurés (beneficiaires).
  • Sex : sexe des assurés “male” ou “female”.
  • bmi : indice de masse corporelle, fournissant une compréhension du corps, des poids relativement élevés ou faibles par rapport à la taille, indice objectif du poids corporel (kg / m ^ 2) utilisant le rapport taille / poids.
  • children : nombre d’enfants couverts par l’assurance maladie / nombre de personnes à charge.
  • smoker : est ce que l’assuré fume ou pas.
  • region: la zone résidentielle du bénéficiaire aux États-Unis, nord-est, sud-est, sud-ouest, nord-ouest.
  • charges: Frais médicaux individuels facturés par l’assurance maladie.

On va convertir les variables sex, children, region, smoker au type factor qui correspond aux variables catégorielles :

insurance_data$sex %<>% as.factor()
insurance_data$children %<>% as.factor()
insurance_data$region %<>% as.factor()
insurance_data$smoker %<>% as.factor()

Description des variables

summary(insurance_data)
##       age            sex           bmi        children smoker    
##  Min.   :18.00   female:662   Min.   :15.96   0:574    no :1064  
##  1st Qu.:27.00   male  :676   1st Qu.:26.30   1:324    yes: 274  
##  Median :39.00                Median :30.40   2:240              
##  Mean   :39.21                Mean   :30.66   3:157              
##  3rd Qu.:51.00                3rd Qu.:34.69   4: 25              
##  Max.   :64.00                Max.   :53.13   5: 18              
##        region       charges     
##  northeast:324   Min.   : 1122  
##  northwest:325   1st Qu.: 4740  
##  southeast:364   Median : 9382  
##  southwest:325   Mean   :13270  
##                  3rd Qu.:16640  
##                  Max.   :63770

Discrétisation de la variable “bmi”

insurance_data %<>% mutate(bmi_cat = cut(bmi,
  breaks = c(0, 18.5, 25, 30, 60),
  labels = c("Under Weight", "Normal Weight", "Overweight", "Obese")
))
  • Under Weight: bmi<18.5
  • Normal Weight: 18.5<bmi<25
  • Overweight : 25=<bmi<30
  • Obese: bmi>=30

Description des variables catégorielles

categ_cols <- insurance_data %>% select_if(~ class(.) == "factor")
for (col in names(categ_cols)) {
  # cat(col , ":")
  # print(round(prop.table(table(insurance_data[[col]])),digits=2))
  # cat("\n")
  t <- insurance_data %>%
    group_by_(col) %>%
    summarise(count = n()) %>%
    mutate(frequency = paste0(round(100 * count / sum(count), 0), "%")) %>% 
    # gt::gt() %>% 
    knitr::kable("html", align = "lcc") %>%
    kableExtra::kable_styling(full_width = F, position = "left") %>% 
    print()
}
sex count frequency
female 662 49%
male 676 51%
children count frequency
0 574 43%
1 324 24%
2 240 18%
3 157 12%
4 25 2%
5 18 1%
smoker count frequency
no 1064 80%
yes 274 20%
region count frequency
northeast 324 24%
northwest 325 24%
southeast 364 27%
southwest 325 24%
bmi_cat count frequency
Under Weight 21 2%
Normal Weight 226 17%
Overweight 386 29%
Obese 705 53%

Observations :
- 50% des individus ont au plus 1 enfant, 75% ont au plus 2.
- la domination d’obésité sur l’échantillon étudié.
- 80% des individus sont des fumeurs.
- le jeu de données est équilibrée par rapport au sex et region.

Proportions des modalités dans les variables catégorielles

show_plot(inspect_cat(insurance_data))

Remarque : Les individus ayant plus que 3 enfants et ceux en sous poids sont négligeables dans notre échantillon donc on va pas les prendre en compte pour la comparaison.

Valeur manquantes

count missing
age 0
sex 0
bmi 0
children 0
smoker 0
region 0
charges 0
bmi_cat 0


le jeu de données ne contient aucune valeur manquante !

Visualisation

Matrice de corrélation

Les variables les plus corrélés avec les charges sont “smoker”, “age” et “bmi”.

Visualisation des interactions entre les variables

#CPCOLS <- c("#617A8F", "#000000", "#B57779", "#FAFAFA", "#FFFFFF", "#FFFFFF")
# continuous:
# upper$continuous with lower$continuous String: ‘points’, ‘smooth’, ‘density’, ‘cor’, ‘blank’
# diag$continuous String: "densityDiag", "barDiag", "blankDiag"
# combo:
# upper$combo with lower$combo String: ‘box’, ‘dot’, ‘facethist’, ‘facetdensity’, ‘denstrip’, ‘blank’
# discrete:
# upper$discrete with lower$discrete String: ‘ratio’, ‘facetbar’, ‘blank’
# diag$discrete String: ‘barDiag’, ‘blankDiag’
CPCOLS <- c("#5B8888", "#bfd1d9", "#B57779", "#F4EDCA")

insurance_data %>%
  dplyr::select(age, bmi, smoker, charges) %>%
  GGally::ggpairs(
    lower = list(
      continuous = GGally::wrap("points", col = "#4c86ad",alpha=0.6),
      combo = GGally::wrap("box",
        fill = "white", col ="black"
          
      )
    ),
    upper = list(
      continuous = GGally::wrap("cor", col = "#4c86ad"),
      combo = GGally::wrap("facetdensity",
        col =
          "black"
      )
    ),
    diag = list(
      continuous = GGally::wrap(
        "barDiag",
        fill = "#f5dfb3", col ="black",
        bins = 18
      ),
      discrete = GGally::wrap("barDiag", fill = "#f5dfb3", col ="black")
    )
  )

Observations :

  • Les distributions des variables :
    • La répartition par âge des assurés est relativement la même, sauf les 18 et 19 ans qui ont une population plus élevée.
    • La distribution de bmi est apparemment normale centrée autour de 30.
    • La distribution des charges est négativement asymétrique.
  • On remarque un effet de ces variables sur les charges, qu’on va explorer plus profondément par la suite.
  • Aucune dépendance importante entre : age & bmi, smoker & bmi, age & smoker.

Distribution des charges

p1 <- ggplot(data = insurance_data, aes(x = charges)) +
  geom_histogram(aes(y = ..density..),
    col = "white",
    bins = 25,
    fill = "#90b8c2"
  ) +
  geom_density() +
  labs(title = "Distribution des charges") +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  )
p2 <- ggplot(data = insurance_data, aes(y = charges)) +
  geom_boxplot(fill = "#90b8c2") +
  ggtitle("Charges Boxplot") +
  theme(
    axis.ticks.x = element_blank(),
    axis.text.x = element_blank()
  ) +
  geom_hline(aes(yintercept = mean(charges)),
    linetype = "dashed", col =
      "red"
  )
grid.arrange(p1, p2,
  ncol = 2, widths = c(5, 2)
)

La distribution des charges est biaisée à droite, on ne peut pas juger les valeurs aberrantes à partir de ce box plot. Donc on va effectuer une transformation logarithmique pour centraliser la distribution.

p1 <- ggplot(data = insurance_data, aes(x = log(charges))) +
  geom_histogram(aes(y = ..density..),
    col = "white",
    bins = 25,
    fill = "#90b8c2"
  ) +
  geom_density() +
  labs(title = "Distribution des charges") +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  )
p2 <- ggplot(data = insurance_data, aes(y = log(charges))) +
  geom_boxplot(fill = "#90b8c2") +
  ggtitle("Charges Boxplot") +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.x = element_blank(),
    axis.text.x = element_blank()
  ) +
  geom_hline(aes(yintercept = mean(log(charges))),
    linetype = "dashed", col =
      "red"
  )
grid.arrange(p1, p2,
  ncol = 2, widths = c(5, 2)
)

On peut maintenant assumer qu’il n’y a aucune anomalie.

Interactions entre l’âge,les conditions de poids et le fait de fumer et leurs impact sur les charges médicales

p1 <- insurance_data %>%
  ggplot(aes(x = age, y = charges, col = bmi_cat)) +
  geom_point(alpha = 0.6, size = 2.5)

p2 <- insurance_data %>%
  ggplot(aes(x = age, y = charges, col = smoker)) +
  geom_point(alpha = 0.8,size = 2.5) +
  scale_color_manual(values = c("#e09e8f", "#90b8c2")) +
  geom_rug() +
  geom_smooth() +
  geom_smooth(
    data = filter(insurance_data, smoker == "yes"),
    col = "grey30",
    method = lm,
    se = FALSE
  ) +
  geom_smooth(
    data = filter(insurance_data, smoker == "no"),
    col = "grey30",
    method = lm,
    se = FALSE
  )

grid.arrange(p1, p2, nrow = 1)

Les charges sont liées à l’âge par une relation quasiment linéaire à trois niveaux :

  • un premier groupe qui se caractérise par les charges les plus élevés, il se compose totalement des individus obèses et fumeurs.
  • un second groupe qui se caractérise par les charges les plus faibles, il se compose totalement des individus non fumeurs et une distribution de bmi normale.
  • et un 3ème groupe non homogène qui nécessite plus d’exploration.

On peut également constater que -pour les trois niveaux- plus les clients sont âgés, plus leurs charges sont élevés.

p1 <- insurance_data %>%
  ggplot(aes(x = bmi, y = charges, col = smoker)) +
  geom_point(size = 2.5, alpha = 0.8) +
  scale_color_manual(values = c("#6f9fb3", "#79a346")) +
  theme(legend.position = c(0.08, 0.85)) +
  geom_smooth(
    data = filter(insurance_data, smoker == "yes"),
    col = "blue"
  ) +
  geom_smooth(
    data = filter(insurance_data, smoker == "yes"),
    col = "black",
    method = lm,
    se = FALSE,
    lty = "dashed"
  ) +
  geom_vline(aes(xintercept = 30), linetype = "dashed") +
  geom_abline(aes(intercept = 12800, slope = 615), col = "red")

p2 <- insurance_data %>%
  ggplot(aes(
    x = cut(bmi, breaks = 8),
    y = charges,
    col = smoker
  )) +
  geom_boxplot() +
  scale_color_manual(values = c("#6f9fb3", "#79a346")) +
  labs(x = "intervalles de bmi") +
  guides(col = FALSE) +
  labs(x = element_blank(), y = "\n\n") +
  theme(axis.text.x = element_text(
    size = 8,
    angle = 40,
    hjust = 1
  )) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  )

p3 <- insurance_data %>%
  filter(smoker == "yes" & bmi_cat != "Under Weight") %>%
  ggplot(aes(x = charges, fill = bmi_cat)) +
  geom_density(alpha = 0.3) +
  scale_color_manual(values = c("#e09e8f", "#90b8c2","#6f9fb3")) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  theme(legend.position = c(0.85, 0.68)) +
  labs(y = "\ndensity\n")
p4 <- insurance_data %>%
  filter(smoker == "no" & bmi_cat != "Under Weight") %>%
  ggplot(aes(x = charges, fill = bmi_cat)) +
  geom_density(alpha = 0.3) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  guides(fill = FALSE) +
  labs(y = "\n\n")
gridExtra::grid.arrange(p1, p2, p3, p4, nrow = 2, heights = c(5, 3))

Pour les non fumeurs, l’obesité n’a aucun impact sur les charges.
Par contre les fumeurs créent presque un nouveau nuage de points séparée des non-fumeurs qui se caractérise par la croissance des charges en fonction du bmi. ce nuage s’éleve fortement lorsque le seuil d’obesité est franchi (bmi=30): Le seuil maximal des charges des non obèses est identique au seuil minimal des charges des obèses.

p1 <- insurance_data %>%
  ggplot(aes(x = bmi, y = charges, col = smoker)) +
  geom_point(size = 3, alpha = 0.8) +
  scale_color_manual(values = c("#b36b56", "#6f9fb3")) +
  geom_density_2d() +
  stat_density_2d(aes(fill = ..level..), geom = "polygon", alpha = 0.35) +
  labs(fill = "density\n") +
  scale_fill_gradient(low = "white", high = "black") +
  theme(
    legend.position = c(0.07, 0.69),
    legend.text = element_blank()
  )
ggExtra::ggMarginal(p1, groupColour = TRUE, groupFill = TRUE)

# ggExtra::ggMarginal(p1,type = "histogram", fill = "grey", col = "black")

On peut distinguer trois centres de densité :
Le premier englobe tous les individus non fumeurs et se caractérise par des charges faible et une répartition normale de bmi. Le 2ème regroupe les individus fumeurs non obèses et ayant des charges relativement élevés.
et le dernier rassemble les individus fumeurs et obèses ayant des charges très élevés.

Création d’une nouvelle variable “Health”

On va créer une variable “health” qui distingue chacun de ces groupes observés :
- “Obese smokers”
- “Non obese smokers”
- “Non smokers”

insurance_data %<>%
  mutate(
    health = case_when(
      smoker == "yes" & bmi_cat == "Obese" ~ "Obese and smoker",
      smoker == "yes" & bmi_cat != "Obese" ~ "Non obese and smoker",
      smoker == "no" ~ "Non smoker"
    )
  )
insurance_data$health %<>% as.factor()

Exploration de “Health”

p1 <- insurance_data %>%
  ggplot(aes(x = charges, fill = health)) +
  geom_density(alpha = 0.3) +
  theme(legend.position = c(0.81, 0.87)) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  labs(y = "density\n")
p2 <- insurance_data %>%
  ggplot(aes(y = charges, x = age, col = health)) +
  geom_point() +
  scale_color_manual(values = c("#d6907c","#93c754", "#77a2b5")) +
  geom_smooth() +
  geom_smooth(
    aes(group = health),
    col = "black",
    method = lm,
    se = FALSE,
    lty = "dashed"
  ) +
  guides(col = FALSE)
p3 <- insurance_data %>%
  ggplot(aes(x = bmi, y = charges, col = health)) +
  geom_point(size = 3, alpha = 0.5) +
  scale_color_manual(values = c("#d6907c","#93c754", "#77a2b5")) +
  stat_ellipse() +
  geom_rug() +
  guides(col = FALSE)
p4 <- insurance_data %>%
  ggplot(aes(x = age, y = charges, col = health)) +
  geom_point(
    size = 4,
    alpha = 0.9,
    shape = "O"
  ) +
  scale_color_manual(values = c("#d6907c","#93c754", "#77a2b5")) +
  stat_ellipse() +
  geom_rug() +
  guides(col = FALSE) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  labs(y = "\n\n")
grid.arrange(p1, p2, p3, p4, nrow = 2)

On peut maintenant constater que les individus obèses et non fumeurs sont moins chargés que les individus non obèses mais fumeurs, En effet ces derniers sont tous situés dans le deuxième niveau du nuage de points des charges en fonction de l’âge. l’impact d’obesité sur les charges est plus important que l’impact de fumer.
Par contre les individus non fumeurs ( obèses et non obèses ) sont majoritairement situés dans le premier niveau mais certains sont dispersé dans le deuxième ce qui nous pousse a chercher le facteur qui cause cette difference.
On conclut alors que le facteur d’obesité sépare les individus fumeurs en deux groupes, un groupe obèse qui est significativement plus chargé que le groupe non obèse, mais ne nous permet pas de tirer une conclusion a propos des individus non fumeurs.

L’impact du genre sur les frais médicaux

p1 <- ggplot(data = insurance_data, aes(sex, charges)) +
  geom_boxplot(fill = c("#cfe6b8", "#8bb054")) +
  labs(x = element_blank()) +
  ggtitle("Boxplot of Medical Charges by Gender")

p2 <- ggplot(insurance_data, aes(reorder(sex, charges), charges)) +
  geom_bar(
    fill = c("#cfe6b8", "#8bb054"),
    col="black",
    position = "dodge",
    stat = "summary",
    fun = "mean"
  ) +
  labs(x = element_blank(), y = element_blank()) +
  ggtitle("Barplot of Medical Charges by Gender")
p1 + p2

On remarque que les frais médicaux sont plus élevés chez les hommes par rapport aux femmes
On se demande si la majorité des fumeurs sont des hommes ?

p1 <- insurance_data %>%
  ggplot(aes(x = smoker, fill = sex)) +
  geom_bar(position = "dodge") +
  theme(legend.position = c(0.8, 0.88)) +
  ggtitle("Count of patients by Gender")+
  scale_fill_manual(values = c("#e09e8f", "#90b8c2")) 
p2 <- insurance_data %>%
  ggplot(aes(sex, charges, fill = smoker)) +
  geom_bar(
    position = "dodge",
    stat = "summary",
    fun = "mean"
  ) +
  labs(x = element_blank(), y = "\n\n\ncharges") +
  ggtitle("Barplot of Medical Charges by Gender")+
  scale_fill_manual(values = c("#e09e8f", "#90b8c2")) 
p3 <- insurance_data %>%
  ggplot(aes(y = charges, x = smoker, fill = smoker)) +
  geom_violin(width = 1.7) +
  geom_boxplot(width = 0.28) +
  guides(fill = FALSE) +
  labs(x = element_blank(), y = element_blank()) +
  ggtitle("Boxplot of Medical Charges by Gender") +
  scale_fill_manual(values = c("#e09e8f", "#90b8c2")) +
  facet_wrap(~sex)

p1 + p2 + p3

On peut maintenant confirmer notre supposition.
et en comparent les feumeurs des deux genres, on trouve que la différence des frais médicaux entres les hommes et les femmes peut être expliquée par deux facteurs :

  • Les hommes fument plus que les femmes (graphe 1)
  • les charges chez les fumeurs sont plus élevée que les charges chez les fumeuses (graphe 2 & 3)

Analyse régionale

p1 <- ggplot(insurance_data, aes(charges,col = region),col = c("#a6bbff" ,"#ffc99c" ,"#f5dfb3", "#e6a1bc")) +
  geom_density( fill = "grey", alpha = 0.05) +
  ggtitle("Density of Medical Charges per Region") +
  theme(legend.position = c(0.82, 0.82)) +
  theme(axis.text.y = element_blank()) +
  labs(y = "density\n")
p2 <- ggplot(insurance_data, aes(reorder(region, charges), charges)) +
  geom_bar(
    fill = c("#c2aebb" , "#f5dfb3","#e09e8f","#90b8c2"),
    position = "dodge",
    stat = "summary",
    fun = "mean"
  ) +
  labs(x = "", y = "\ncharges") +
  ggtitle("Average of Medical Charges per Region")
p1 + p2

On constate que les frais médicaux sont également distribués sur les régions avec une légère augmentation dans le sud-est. Cherchons la cause dérrière cette différence!

Tableau de contingence

cat("% de chaque chaque categorie de santé par région :\n")
## % de chaque chaque categorie de santé par région :
round(100 * prop.table(table(
  insurance_data$health,
  insurance_data$region
), margin = 2), digits = 2)
##                       
##                        northeast northwest southeast southwest
##   Non obese and smoker     11.73     10.77      9.07      7.38
##   Non smoker               79.32     82.15     75.00     82.15
##   Obese and smoker          8.95      7.08     15.93     10.46
cat("\n\n % des fumeurs par région :\n")
## 
## 
##  % des fumeurs par région :
round(100 * prop.table(table(
  insurance_data$smoker,
  insurance_data$region
), margin = 2), digits = 2)
##      
##       northeast northwest southeast southwest
##   no      79.32     82.15     75.00     82.15
##   yes     20.68     17.85     25.00     17.85
cat("\n\n % de chaque chaque categorie d'obésité par région :\n")
## 
## 
##  % de chaque chaque categorie d'obésité par région :
round(100 * prop.table(
  table(insurance_data$bmi_cat,
        insurance_data$region),
  margin = 2
), digits = 2)
##                
##                 northeast northwest southeast southwest
##   Under Weight       3.09      2.15      0.00      1.23
##   Normal Weight     22.53     19.38     11.26     15.08
##   Overweight        30.25     32.92     21.98     31.08
##   Obese             44.14     45.54     66.76     52.62
ggpubr::ggballoonplot(as.data.frame(table(dplyr::select(
  insurance_data, region, bmi_cat
))),
fill = "value") +
  scale_fill_viridis_c(option = "D")

library(FactoMineR)
library(factoextra)
res.ca <-
  CA(table(dplyr::select(insurance_data, region, health)), graph = FALSE)
fviz_ca_biplot(res.ca, repel = TRUE)

Grâce aux tableaux de contingence et au graphique ci-dessus, on peut éxpliquer les charges élevés dans le sud-est par la présence d’une proportion importante des patients fumeurs et obèses par rapport aux autres régions.
et à l’aide du graphe de l’AFC on peut visualiser la dépendance entre les modalités de “health” et “region”, on voit du premier axe qui représente 83.5% de l’information que la région “southeast” se distingue du reste par une dépendance remarquable avec la classe “Obese and smoker”.

Est ce que le nombre d’enfants de l’assuré influence ses charges ?

p1 <- ggplot(filter(insurance_data, children %in% c(0:3))) +
  # geom_density(aes(charges,fill = children), alpha = 0.4) +
  # theme(legend.position = c(0.85, 0.78)) +
  # theme(axis.text.y=element_blank(), axis.ticks.y=element_blank())+
  # labs(fill="children\n")
  geom_density_ridges(
    aes(
      x = charges,
      y = children,
      fill = children,
      height = stat(density)
    ),
    alpha=0.6,
    stat = "density",
    scale = 2.25
  ) +
  # alpha = 1,rel_min_height = 0.002, quantile_lines = TRUE) +
  scale_fill_manual(values = c("#f5dfb3","#e09e8f","#b36b56","#8bbfaf")) +
  # scale_x_continuous(expand = c(0, 0)) +
  scale_y_discrete(expand = expand_scale(mult = c(0.07, .7))) +
  coord_cartesian(clip = "off") +
  theme(
    legend.position = c(0.88, 0.82),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  ) +
  # theme_ridges(font_size = 13, grid = TRUE)+
  labs(y = "density")
p2 <- insurance_data %>%
  ggplot(aes(reorder(children, charges), charges)) +
  geom_bar(
    position = "dodge",
    stat = "summary",
    fun = "mean",
    fill = c("#f5dfb3","#e09e8f","#b36b56","#8bbfaf")
  ) +
  scale_x_discrete(limits = factor(c(0:3))) +
  labs(x = "children", y = "\n\ncharges\n",title="Average of medical charges by number of children")
p1 + p2

Les charges sont plus élevés chez les parents.
Les charges des non parents se caractérisent par une densité particulière.

On va visualiser cette différence avec une nouvelle variable “parent” de deux modalités “yes” et “no”.

insurance_data$parent <-
  as.factor(ifelse(insurance_data$children == 0, "no", "yes"))

Différence entre les charges chez les assurés parents et non parents

p1 <- insurance_data %>%
  ggplot(aes(y = charges, x = parent, fill = parent)) +
  geom_violin(width = 0.8,alpha=0.7) +
  geom_boxplot(col = "grey30", alpha = 0.6, width = 0.13) +
  theme(legend.position = c(0.9, 0.9)) +
  labs(x = element_blank(), y = element_blank()) +
  scale_fill_manual(values = c("#a7c76c","#d19bbb"))
p <- insurance_data %>%
  ggplot() +
  geom_point(aes(x = age, y = charges, col = parent),
             size = 1,
             alpha = 0.8) +
  scale_color_manual(values=c("#79a346","#d19bbb"))+
  theme(axis.text.y = element_blank(),
        axis.ticks.y = element_blank())

p2 <- ggExtra::ggMarginal(p + theme(legend.position = "none"),

  groupColour = TRUE,
  groupFill = TRUE
)
grid.arrange(p1, p2,
  ncol = 2, widths = c(2, 3)
)

La distinction entre les charges des parents et non parents est plus significatif dans le premier niveau.
La majorité des individus âgés de 30 à 50 ans sont parents, alors que les jeunes et vieillards ne sont pas parents d’enfants couverts par l’assurance.

ad <- 0.9
p1 <- insurance_data %>%
  filter(parent == "yes") %>%
  ggplot(aes(x = age, y = charges, col = health)) +
  scale_color_manual(values = c("#79a346", "#245373","#964a74"))+
  geom_point(size = 0.8, alpha = 0.7) +
  stat_density_2d(
    aes(fill = ..level..),
    geom = "polygon",
    alpha = 0.5,
    adjust = ad
  ) +
  labs(fill = "densité\n") +
  scale_fill_gradient(low = "white", high = "black") +
  guides(fill = FALSE, col = FALSE)
p3 <- ggExtra::ggMarginal(p1, groupColour = TRUE, groupFill = TRUE)
p2 <- insurance_data %>%
  filter(parent == "no") %>%
  ggplot(aes(x = age, y = charges, col = health)) +
  geom_point(size = 0.8, alpha = 0.7) +
  scale_color_manual(values = c("#79a346", "#245373","#964a74")) +
  stat_density_2d(
    aes(fill = ..level..),
    geom = "polygon",
    alpha = 0.5,
    adjust = ad
  ) +
  labs(y = "\n") +
  scale_fill_gradient(low = "white", high = "black") +
  guides(fill = FALSE, col = guide_legend(order=1))+
  theme(legend.position = "left")
p4 <- ggExtra::ggMarginal(p2, groupColour = TRUE, groupFill = TRUE)
grid.arrange(p3, p4, nrow = 1, widths = c(6, 9))

Due à la relation entre l’âge et les charges :
Les parents ont des charges centré sur la moyenne de chaque niveau, contrairement aux non parents qui rassemblent à la fois des charges élevés (des plus agés) et des charges bas (des moins agés)
Cette différence est moins remarquable quand les charges sont élevés à cause de la dominations des charges individuels sur les charges des enfants qui deviennent négligeables.

Modélisation

Modéle linéaire

Initialisation du modèle

On va commencer par un modèle linéaire simple en utilisant les deux variables age et health. Centraliser la variables âge (pour permettre l’interpretation de l’intercept).

insurance_data %<>% mutate(centred_age = age - 18)
insurance_data$health = relevel(insurance_data$health, ref = "Non smoker")
mod.basic = lm(charges ~ centred_age + health,
               data = insurance_data,
               x = TRUE,
               y = TRUE)
summary(mod.basic)
## 
## Call:
## lm(formula = charges ~ centred_age + health, data = insurance_data, 
##     x = TRUE, y = TRUE)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -5585.7 -1953.5 -1324.8  -394.4 24493.8 
## 
## Coefficients:
##                             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)                 2693.654    234.102   11.51   <2e-16 ***
## centred_age                  268.437      8.816   30.45   <2e-16 ***
## healthNon obese and smoker 13343.999    420.794   31.71   <2e-16 ***
## healthObese and smoker     33334.018    401.955   82.93   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 4527 on 1334 degrees of freedom
## Multiple R-squared:  0.8606, Adjusted R-squared:  0.8603 
## F-statistic:  2745 on 3 and 1334 DF,  p-value: < 2.2e-16
mod.basic %>%
  augment() %>%
  ggplot(aes(x = centred_age, y = charges)) +
  geom_point(aes(col = health)) +
  scale_color_manual(values = c("#79a346", "#245373","#964a74")) +
  geom_line(aes(y = .fitted)) +
  geom_segment(aes(xend = centred_age, yend = .fitted), col = "grey70") +
  facet_grid( ~ health)

mod.basic %>%
  augment() %>%
  ggplot(aes(x = .fitted, y = charges)) +
  geom_point(shape = 1, aes(col = health)) +
  scale_color_manual(values = c("#79a346", "#245373","#964a74")) +
  geom_abline(slope = 1, lty = "dashed")

Sur ce premier modèle:
- on a obtenu un R-carré ajusté de 0.86, ce qui signifie que 86% de la variation des charges pourrait être expliquée par les deux variables indépendantes age et health.
- on a également pu observer que toutes les variables indépendantes sont des prédicteurs statistiquement significatifs des frais médicaux (p.value inférieure à 0,05 <- niveau de signification).
- La moyenne des charges pour un assuré de 18 ans non fumeur est 2693.654.
- Une année de plus dans l’âge augmente les charges par 268.437 pour la même catégorie de health.
- Par rapport aux non fumeurs, en contrôlant l’âge :
     Les fumeurs non obèses ont en moyenne 13343.999 plus de charges.
     Les fumeurs obèses ont en moyenne 33334.018 plus de charges.

Performance du modèle initial :

mod.basic %>%
  glance() %>%
  dplyr::select(r.squared,
         adj.r.squared,
         AIC,
         BIC,
         statistic,
         p.value,
         df.residual,
         nobs) %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
r.squared adj.r.squared AIC BIC statistic p.value df.residual nobs
0.8605842 0.8602707 26329.01 26355 2744.834 0 1334 1338

verification des hypothèses de la régression pour le modèle initial

Vérification de la linéarité des données

plot(mod.basic, which=1)

On peut assumer une relation linéaire entre les prédicteurs et la variable cible.

Verification de l’homogénité de la variance par spread-location plot

plot(mod.basic, which=3)

On costance que la variabilité des résidus en fonction des valeurs prédites a une forme quadratique pour un groupe d’invidus, ce qui indique l’hétéroscédasticité (variances non constantes des erreurs)

ce fait on amène à effectuer une transformation sur la variable âge.

Amélioration du modèle

mod.age2 = lm(charges ~ I(age ^ 2) + health, data = insurance_data)
summary(mod.age2)
## 
## Call:
## lm(formula = charges ~ I(age^2) + health, data = insurance_data)
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -4980.7 -1975.6 -1425.1  -264.2 23826.2 
## 
## Coefficients:
##                             Estimate Std. Error t value Pr(>|t|)    
## (Intercept)                2.553e+03  2.368e+02   10.79   <2e-16 ***
## I(age^2)                   3.362e+00  1.098e-01   30.62   <2e-16 ***
## healthNon obese and smoker 1.339e+04  4.199e+02   31.90   <2e-16 ***
## healthObese and smoker     3.331e+04  4.010e+02   83.06   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 4516 on 1334 degrees of freedom
## Multiple R-squared:  0.8612, Adjusted R-squared:  0.8609 
## F-statistic:  2760 on 3 and 1334 DF,  p-value: < 2.2e-16

Performance :

mod.age2 %>%
  glance() %>%
  dplyr::select(r.squared,
         adj.r.squared,
         AIC,
         BIC,
         statistic,
         p.value,
         df.residual,
         nobs) %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
r.squared adj.r.squared AIC BIC statistic p.value df.residual nobs
0.8612288 0.8609168 26322.81 26348.8 2759.65 0 1334 1338
plot(mod.age2, which = 1)

plot(mod.age2, which = 3)

On peut assumer que les deux premières hypothèses sont bien vérifiés pour ce deuxième modèle.
On va essayer d’améliorer encore la précision en ajoutant d’autres variables explicative :

mod.extend = lm(charges ~ I(age ^ 2) + health + smoker:bmi + children + region +
                  sex,
                data = insurance_data)
summary(mod.extend)
## 
## Call:
## lm(formula = charges ~ I(age^2) + health + smoker:bmi + children + 
##     region + sex, data = insurance_data)
## 
## Residuals:
##    Min     1Q Median     3Q    Max 
##  -3271  -1496  -1217   -849  24210 
## 
## Coefficients:
##                              Estimate Std. Error t value Pr(>|t|)    
## (Intercept)                 2189.2344   747.2089   2.930 0.003449 ** 
## I(age^2)                       3.3456     0.1068  31.322  < 2e-16 ***
## healthNon obese and smoker  1461.1728  1960.6927   0.745 0.456263    
## healthObese and smoker     16616.0387  2635.5680   6.305 3.93e-10 ***
## children1                    886.3802   302.7192   2.928 0.003469 ** 
## children2                   1812.5903   335.3965   5.404 7.71e-08 ***
## children3                   1529.6219   393.1487   3.891 0.000105 ***
## children4                   4098.2667   890.5092   4.602 4.58e-06 ***
## children5                   2174.2555  1046.7760   2.077 0.037985 *  
## regionnorthwest             -394.3059   342.6695  -1.151 0.250068    
## regionsoutheast             -895.1742   344.7422  -2.597 0.009518 ** 
## regionsouthwest            -1189.3192   343.6165  -3.461 0.000555 ***
## sexmale                     -522.6149   239.4042  -2.183 0.029213 *  
## smokerno:bmi                  14.6839    23.0145   0.638 0.523564    
## smokeryes:bmi                486.1988    71.3217   6.817 1.41e-11 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 4352 on 1323 degrees of freedom
## Multiple R-squared:  0.8722, Adjusted R-squared:  0.8708 
## F-statistic: 644.8 on 14 and 1323 DF,  p-value: < 2.2e-16

En ajoutant d’autres prédicteurs le r-carrée ajustée augmente à 87.08%.

On va tester l’importance des variables sex et region dans la performance du modèle
H0 : l’ajout de la variable n’a aucune importance

anova(lm(charges ~ I(age ^ 2) + health + smoker:bmi + children + sex , data = insurance_data) ,
      mod.extend)
## Analysis of Variance Table
## 
## Model 1: charges ~ I(age^2) + health + smoker:bmi + children + sex
## Model 2: charges ~ I(age^2) + health + smoker:bmi + children + region + 
##     sex
##   Res.Df        RSS Df Sum of Sq      F   Pr(>F)   
## 1   1326 2.5326e+10                                
## 2   1323 2.5062e+10  3 263504216 4.6366 0.003128 **
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
anova(
  lm(charges ~ I(age ^ 2) + health + smoker:bmi + children + region , data = insurance_data) ,
  mod.extend
)
## Analysis of Variance Table
## 
## Model 1: charges ~ I(age^2) + health + smoker:bmi + children + region
## Model 2: charges ~ I(age^2) + health + smoker:bmi + children + region + 
##     sex
##   Res.Df        RSS Df Sum of Sq      F  Pr(>F)  
## 1   1324 2.5153e+10                              
## 2   1323 2.5062e+10  1  90273982 4.7654 0.02921 *
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Avec un niveau de risque 5%, On rejete l’hypothèse H0 pour les deux cas.
Donc, on garde les deux variables dans le modèle.

Vérification des hypothèses pour le nouveau modèle amélioré

Linéarité et normalité des résidus:

par(mfrow = c(1, 3))
for (i in c(1, 3, 2)) {
  plot(mod.extend, which = i)
}

la distribution des résidus n’est pas normale.
l’hypothèse de linéarité est vérifiée et la variance semble constante, mais on va effectuer un test pour s’assurer.

Hétéroscedasticité:

library(lmtest)
bptest(mod.extend) # tester l'existence de l'hétéroscédasticité lineaire
## 
##  studentized Breusch-Pagan test
## 
## data:  mod.extend
## BP = 16.844, df = 14, p-value = 0.2646
ncvTest(mod.extend)# tester l'existence de l'hétéroscédasticité non lineaire
## Non-constant Variance Score Test 
## Variance formula: ~ fitted.values 
## Chisquare = 8.567156, Df = 1, p = 0.0034228

Il n’y a aucun signe d’hétéroscédasticité.

par(mfrow = c(2, 2))
for (i in c(4:6)) {
  plot(mod.extend, which = i)
}
influencePlot(mod.extend, main = "Influence Plot")

##          StudRes         Hat        CookD
## 322   3.70355755 0.043793647 4.148127e-02
## 517   5.64758608 0.007301576 1.528300e-02
## 1013  4.32167210 0.046084448 5.935982e-02
## 1048 -0.22975456 0.086606152 3.339172e-04
## 1086  0.09609113 0.077335829 5.163424e-05
## 1301  5.54943002 0.018034681 3.687613e-02

On a absence de points très influents dans le modèle, malgré l’existence des points ayant un résidu important mais ne sont pas extrêmes par rapport aux prédicteurs, leurs faible hat-values ne leurs permet pas d’influencer la droite de régression, ainsi que des points extrêmes ayant un faible résidu.

les points ayant les plus grandes valeurs de cook’s distance (mesure d’influence basée sur le leverage et le résidu) sont plus caractérisées par un grand résidu (ordre de grandeur > 3) => l’erreur contribue plus dans l’influence de ces points.

perf=mod.extend %>%
  glance() %>%
  dplyr::select(r.squared,
         adj.r.squared,
         AIC,
         BIC,
         statistic,
         p.value,
         df.residual,
         nobs)
perf%>% knitr::kable(align = "c") %>% 
        kableExtra::kable_styling(full_width = F, position = "left")
r.squared adj.r.squared AIC BIC statistic p.value df.residual nobs
0.872179 0.8708264 26234.83 26318.01 644.8153 0 1323 1338

Contribution des variables dans le modèle :

library(relaimpo)
mod.extend.shapley <-
  calc.relimp(lm(charges ~ I(age ^ 2) + health + bmi + children + region +
                   sex,
                 data = insurance_data),
              type = "lmg") 
library(caret)
mod.extend.shapley$lmg %>%
  sort(T) %>% 
  data.frame()  -> temp
names(temp) <- c("Shaply_value")
p = temp %>% 
  cbind(name = rownames(.)) %>% 
  ggplot()+
  geom_col(aes(reorder(name, -Shaply_value), Shaply_value), stat = "identity",
           fill = c("#4a9687","#c2aebb","#f5dfb3","#e09e8f","#cfe6b8","#f2d696"), col="black")+
  ggtitle("Relative Importance of Predictors")+
  theme(plot.title = element_text(hjust = 0.5))+
  labs(x="Predictor Labels", y="Shapley Value Regression")
grid.arrange(p ,tableGrob(round(temp,4)), ncol = 2, widths = c(4, 2))

Pour dériver l’importance des variables prédicteurs dans ce modèle. on va utiliser une méthode statistique appelée régression des valeurs shapley qui est une solution issue du concept de théorie des jeux. Son objectif est de répartir équitablement l’importance des prédicteurs dans l’analyse de régression. Étant donné n nombre de variables indépendantes (IV), on va exécuter toutes les combinaisons de modèles de régression linéaire en utilisant cette liste de (IV) par rapport à la variable dépendante (DV) et obtenir le R-carré de chaque modèle. Pour obtenir la mesure de l’importance de chaque variable indépendante (IV), la contribution moyenne au R-carré total de chaque (IV) est calculée en décomposant le R-carré total et en calculant la contribution marginale proportionnelle de chaque (IV).

Supposant qu’on a 2 (IV) A et B et une variable dépendante Y. on doit construire 3 modèles comme suit: 1) Y ~ A 2) Y ~ B 3) Y ~ A + B et chaque modèle aurait leur R-carré respectif .

Pour obtenir la valeur Shapley de A, on va décomposer le r-carré du troisième modèle et calculer la contribution marginale de l’attribut A.

Valeur Shapley (A) = \(\frac{(R^2 (AB) - R^2 (B)) + R^2 (A)}{2}\)

on a utilisé la fonction calc.relimp() du package relaimpo pour déterminer la valeur Shapley de nos prédicteurs.

Les scores de Shapley Value de chaque attribut montrent leur contribution marginale au r carré global (0.8722) du deuxième modèle. on peut donc conclure que, sur la variance totale de 87,22% expliquée par notre modèle, presque 75% de celle-ci est due à l’attribut “health”.

Modèle additif généralisé

library(mgcv)
# Build the model
mod.gam <- gam(charges ~ s(age) + health + s(bmi) + children + region +
                 sex,
               data = insurance_data)
summary(mod.gam)
## 
## Family: gaussian 
## Link function: identity 
## 
## Formula:
## charges ~ s(age) + health + s(bmi) + children + region + sex
## 
## Parametric coefficients:
##                            Estimate Std. Error t value Pr(>|t|)    
## (Intercept)                  8335.7      314.7  26.486  < 2e-16 ***
## healthNon obese and smoker  13603.2      430.7  31.584  < 2e-16 ***
## healthObese and smoker      33216.6      411.2  80.780  < 2e-16 ***
## children1                     970.9      320.6   3.028 0.002508 ** 
## children2                    1844.7      355.5   5.189 2.44e-07 ***
## children3                    1593.2      408.7   3.899 0.000102 ***
## children4                    4093.3      908.3   4.507 7.17e-06 ***
## children5                    2090.6     1067.6   1.958 0.050414 .  
## regionnorthwest              -327.5      347.8  -0.942 0.346553    
## regionsoutheast              -786.3      349.2  -2.252 0.024512 *  
## regionsouthwest             -1162.8      348.6  -3.335 0.000876 ***
## sexmale                      -481.9      242.5  -1.987 0.047130 *  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Approximate significance of smooth terms:
##          edf Ref.df       F p-value    
## s(age) 3.083  3.832 246.331  <2e-16 ***
## s(bmi) 3.322  4.206   2.001   0.085 .  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## R-sq.(adj) =  0.867   Deviance explained = 86.9%
## GCV = 1.9721e+07  Scale est. = 1.945e+07  n = 1338
mod.gam %>%
  glance() %>%
  mutate(adj.r.squared = 0.867) %>%
  dplyr::select(adj.r.squared, AIC, BIC, nobs) %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
adj.r.squared AIC BIC nobs
0.867 26273.5 26374.39 1338

On n’observe aucune amélioration dans le modele additif généralisé par rapport au modéle précèdent.Donc on préfère le modèle linéaire pour des raisons de simplicité et interpretabilité.

K-nearest neighbour

Modèle initial

Afin de pouvoir évaluer le modèle knn, on va diviser notre jeu de données en deux parties : 80% pour le train set, et 20% pour le test set.

set.seed(12345)
insurance_data_split <-
  initial_split(insurance_data, prop = .8, strata = "charges")
insurance_train <- training(insurance_data_split)
insurance_test <- testing(insurance_data_split)

A l’aide du package parsnip, On va instancier un modéle knn pour la regression et l’entrainer sur la train set.

library(parsnip)
library(kknn)
model_knn <- nearest_neighbor(mode = "regression") %>%
  set_engine("kknn") %>%
  fit(charges ~ age + bmi + children + health +
        region, data = insurance_train)

Puis on va utiliser ce modèle pour effectuer les prédictons sur le test set afin d’évaluer sa performance.

model_knn_pred <- model_knn %>%
  predict(new_data = insurance_test) %>%
  bind_cols(insurance_test %>% dplyr::select(charges))

modelPerform = data.frame(
  RMSE = RMSE(model_knn_pred$.pred, model_knn_pred$charges),
  R2 = R2(model_knn_pred$.pred, model_knn_pred$charges)
)
modelPerform %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
RMSE R2
5032.466 0.8481833

Amélioration du modèle

Pour améliorer encore plus la performance, on va essayer de trouver le k optimal.

Pour automatiser les prétraitements, on va commencer par créer une recipe pour standariser les variables numeriques age et bmi, et la combiner avec le modèle knn dans un workflow.

#creating recipe
ins_recipe <-
  recipe(charges ~ age + bmi + children + health + region, data = insurance_train) %>%
  step_scale(age, bmi) %>%
  step_center(age, bmi)
#model specification
ins_spec <-
  nearest_neighbor(weight_func = "rectangular", neighbors = tune()) %>%
  set_engine("kknn") %>%
  set_mode("regression")
#workflow(model+recipe)
ins_wf <- workflow() %>%
  add_recipe(ins_recipe) %>%
  add_model(ins_spec)
ins_wf
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: nearest_neighbor()
## 
## -- Preprocessor ----------------------------------------------------------------
## 2 Recipe Steps
## 
## * step_scale()
## * step_center()
## 
## -- Model -----------------------------------------------------------------------
## K-Nearest Neighbor Model Specification (regression)
## 
## Main Arguments:
##   neighbors = tune()
##   weight_func = rectangular
## 
## Computational engine: kknn

Pour trouver le k optimale pour le modèle, on va utiliser la méthode de validation du cross-fold avec 5 folds, qui sert à diviser les données sur cinq partie, et à chaque itération réserver une partie pour la validation et utiliser les autres pour l’entrainement, puis calculer la performance en utilisant la moyenne des perfarmances sur les différents folds.

Ici on va chercher le k optimale dans l’intervalle [1,100], et mesurer la performance avec cross-fold validation pour réduire l’erreur due à l’interaction possible entre les variables lors d’une simple division.

set.seed(123456)
ins_vfold <- vfold_cv(insurance_train, v = 5, strata = charges)

gridvals <- tibble(neighbors = seq(1, 100))

ins_results <- ins_wf %>%
  tune_grid(resamples = ins_vfold, grid = gridvals) %>%
  collect_metrics()

ins_results  %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = T, position = "left")  %>% 
  kableExtra::kable_paper() %>%
  kableExtra::scroll_box( height = "300px")
neighbors .metric .estimator mean n std_err .config
1 rmse standard 6178.9140580 5 173.3195772 Preprocessor1_Model001
1 rsq standard 0.7523178 5 0.0152306 Preprocessor1_Model001
2 rmse standard 5287.3902576 5 182.8131372 Preprocessor1_Model002
2 rsq standard 0.8068960 5 0.0170376 Preprocessor1_Model002
3 rmse standard 5100.9016560 5 202.0380177 Preprocessor1_Model003
3 rsq standard 0.8185064 5 0.0169442 Preprocessor1_Model003
4 rmse standard 4963.0848913 5 227.1004810 Preprocessor1_Model004
4 rsq standard 0.8269390 5 0.0187191 Preprocessor1_Model004
5 rmse standard 4890.3112905 5 253.3037270 Preprocessor1_Model005
5 rsq standard 0.8315022 5 0.0196675 Preprocessor1_Model005
6 rmse standard 4851.6850923 5 255.4585673 Preprocessor1_Model006
6 rsq standard 0.8336937 5 0.0195599 Preprocessor1_Model006
7 rmse standard 4775.7779870 5 252.4518214 Preprocessor1_Model007
7 rsq standard 0.8392938 5 0.0185528 Preprocessor1_Model007
8 rmse standard 4773.4450843 5 255.8713731 Preprocessor1_Model008
8 rsq standard 0.8396510 5 0.0185616 Preprocessor1_Model008
9 rmse standard 4733.5640708 5 256.0109169 Preprocessor1_Model009
9 rsq standard 0.8427674 5 0.0180070 Preprocessor1_Model009
10 rmse standard 4715.0323446 5 263.7646139 Preprocessor1_Model010
10 rsq standard 0.8437694 5 0.0184588 Preprocessor1_Model010
11 rmse standard 4693.7602162 5 264.6281078 Preprocessor1_Model011
11 rsq standard 0.8450957 5 0.0183661 Preprocessor1_Model011
12 rmse standard 4675.7965036 5 266.1327002 Preprocessor1_Model012
12 rsq standard 0.8460284 5 0.0184862 Preprocessor1_Model012
13 rmse standard 4658.2517213 5 264.0859288 Preprocessor1_Model013
13 rsq standard 0.8470701 5 0.0183785 Preprocessor1_Model013
14 rmse standard 4653.9054267 5 267.4354491 Preprocessor1_Model014
14 rsq standard 0.8472995 5 0.0184257 Preprocessor1_Model014
15 rmse standard 4649.7572197 5 274.8265877 Preprocessor1_Model015
15 rsq standard 0.8474079 5 0.0189709 Preprocessor1_Model015
16 rmse standard 4663.8376805 5 270.9116049 Preprocessor1_Model016
16 rsq standard 0.8465355 5 0.0188114 Preprocessor1_Model016
17 rmse standard 4646.5558310 5 265.4033259 Preprocessor1_Model017
17 rsq standard 0.8476093 5 0.0184269 Preprocessor1_Model017
18 rmse standard 4648.9109798 5 258.6419467 Preprocessor1_Model018
18 rsq standard 0.8476848 5 0.0178239 Preprocessor1_Model018
19 rmse standard 4640.7203726 5 260.4874734 Preprocessor1_Model019
19 rsq standard 0.8482083 5 0.0178772 Preprocessor1_Model019
20 rmse standard 4629.8374928 5 263.9582032 Preprocessor1_Model020
20 rsq standard 0.8490991 5 0.0178641 Preprocessor1_Model020
21 rmse standard 4629.1660454 5 260.0895097 Preprocessor1_Model021
21 rsq standard 0.8491335 5 0.0176451 Preprocessor1_Model021
22 rmse standard 4634.4062304 5 262.3567633 Preprocessor1_Model022
22 rsq standard 0.8489441 5 0.0177543 Preprocessor1_Model022
23 rmse standard 4647.4361189 5 261.3037340 Preprocessor1_Model023
23 rsq standard 0.8481788 5 0.0177749 Preprocessor1_Model023
24 rmse standard 4655.3149210 5 258.6656146 Preprocessor1_Model024
24 rsq standard 0.8476871 5 0.0176269 Preprocessor1_Model024
25 rmse standard 4654.8456030 5 257.9780451 Preprocessor1_Model025
25 rsq standard 0.8477278 5 0.0176100 Preprocessor1_Model025
26 rmse standard 4669.6677147 5 262.0185012 Preprocessor1_Model026
26 rsq standard 0.8467377 5 0.0179200 Preprocessor1_Model026
27 rmse standard 4675.1333349 5 253.9520483 Preprocessor1_Model027
27 rsq standard 0.8467449 5 0.0173183 Preprocessor1_Model027
28 rmse standard 4686.8939256 5 251.7401776 Preprocessor1_Model028
28 rsq standard 0.8461133 5 0.0171720 Preprocessor1_Model028
29 rmse standard 4697.4520238 5 251.6249278 Preprocessor1_Model029
29 rsq standard 0.8456957 5 0.0170900 Preprocessor1_Model029
30 rmse standard 4697.7346408 5 247.4222881 Preprocessor1_Model030
30 rsq standard 0.8460997 5 0.0168078 Preprocessor1_Model030
31 rmse standard 4707.8854025 5 248.8065567 Preprocessor1_Model031
31 rsq standard 0.8457563 5 0.0166483 Preprocessor1_Model031
32 rmse standard 4725.2919175 5 248.3321147 Preprocessor1_Model032
32 rsq standard 0.8451747 5 0.0165526 Preprocessor1_Model032
33 rmse standard 4735.5588714 5 252.6255563 Preprocessor1_Model033
33 rsq standard 0.8450080 5 0.0167783 Preprocessor1_Model033
34 rmse standard 4746.0761373 5 250.6010354 Preprocessor1_Model034
34 rsq standard 0.8450294 5 0.0165461 Preprocessor1_Model034
35 rmse standard 4762.7821750 5 252.4975171 Preprocessor1_Model035
35 rsq standard 0.8446822 5 0.0167139 Preprocessor1_Model035
36 rmse standard 4766.5267641 5 255.1530408 Preprocessor1_Model036
36 rsq standard 0.8452888 5 0.0168055 Preprocessor1_Model036
37 rmse standard 4782.1438923 5 256.0248366 Preprocessor1_Model037
37 rsq standard 0.8448535 5 0.0167685 Preprocessor1_Model037
38 rmse standard 4806.4956334 5 254.7956850 Preprocessor1_Model038
38 rsq standard 0.8440514 5 0.0167226 Preprocessor1_Model038
39 rmse standard 4818.2964146 5 257.1599681 Preprocessor1_Model039
39 rsq standard 0.8439850 5 0.0167531 Preprocessor1_Model039
40 rmse standard 4820.8396510 5 254.8822517 Preprocessor1_Model040
40 rsq standard 0.8447060 5 0.0165704 Preprocessor1_Model040
41 rmse standard 4840.8191847 5 256.0101246 Preprocessor1_Model041
41 rsq standard 0.8439965 5 0.0166757 Preprocessor1_Model041
42 rmse standard 4852.4308016 5 254.5219178 Preprocessor1_Model042
42 rsq standard 0.8438103 5 0.0166872 Preprocessor1_Model042
43 rmse standard 4867.6603943 5 254.1660848 Preprocessor1_Model043
43 rsq standard 0.8441141 5 0.0163858 Preprocessor1_Model043
44 rmse standard 4880.2858251 5 254.0490051 Preprocessor1_Model044
44 rsq standard 0.8442817 5 0.0163405 Preprocessor1_Model044
45 rmse standard 4896.9124752 5 256.7190085 Preprocessor1_Model045
45 rsq standard 0.8441670 5 0.0165102 Preprocessor1_Model045
46 rmse standard 4907.1107505 5 251.0990212 Preprocessor1_Model046
46 rsq standard 0.8442302 5 0.0161694 Preprocessor1_Model046
47 rmse standard 4934.1177870 5 250.3514229 Preprocessor1_Model047
47 rsq standard 0.8434454 5 0.0162412 Preprocessor1_Model047
48 rmse standard 4950.5952668 5 250.8397584 Preprocessor1_Model048
48 rsq standard 0.8431428 5 0.0163417 Preprocessor1_Model048
49 rmse standard 4971.3220377 5 252.7660356 Preprocessor1_Model049
49 rsq standard 0.8427938 5 0.0164434 Preprocessor1_Model049
50 rmse standard 4991.3165169 5 254.1472657 Preprocessor1_Model050
50 rsq standard 0.8424261 5 0.0164899 Preprocessor1_Model050
51 rmse standard 5010.3814183 5 258.9393462 Preprocessor1_Model051
51 rsq standard 0.8419906 5 0.0168380 Preprocessor1_Model051
52 rmse standard 5032.7688162 5 256.3764770 Preprocessor1_Model052
52 rsq standard 0.8419245 5 0.0168224 Preprocessor1_Model052
53 rmse standard 5056.2074325 5 255.6109658 Preprocessor1_Model053
53 rsq standard 0.8415965 5 0.0167222 Preprocessor1_Model053
54 rmse standard 5070.5846465 5 256.6988728 Preprocessor1_Model054
54 rsq standard 0.8417833 5 0.0168816 Preprocessor1_Model054
55 rmse standard 5092.3323325 5 254.2759554 Preprocessor1_Model055
55 rsq standard 0.8417379 5 0.0166438 Preprocessor1_Model055
56 rmse standard 5113.9907711 5 252.9157752 Preprocessor1_Model056
56 rsq standard 0.8418789 5 0.0168044 Preprocessor1_Model056
57 rmse standard 5141.3076935 5 253.8617620 Preprocessor1_Model057
57 rsq standard 0.8410924 5 0.0169582 Preprocessor1_Model057
58 rmse standard 5164.6103863 5 254.3423992 Preprocessor1_Model058
58 rsq standard 0.8409587 5 0.0171882 Preprocessor1_Model058
59 rmse standard 5187.1686629 5 255.3040686 Preprocessor1_Model059
59 rsq standard 0.8407301 5 0.0174055 Preprocessor1_Model059
60 rmse standard 5206.3796593 5 255.0693833 Preprocessor1_Model060
60 rsq standard 0.8407261 5 0.0173943 Preprocessor1_Model060
61 rmse standard 5226.1861062 5 251.1256514 Preprocessor1_Model061
61 rsq standard 0.8408629 5 0.0173581 Preprocessor1_Model061
62 rmse standard 5246.6497017 5 250.4438026 Preprocessor1_Model062
62 rsq standard 0.8408098 5 0.0172474 Preprocessor1_Model062
63 rmse standard 5260.6746375 5 253.8121305 Preprocessor1_Model063
63 rsq standard 0.8411199 5 0.0173674 Preprocessor1_Model063
64 rmse standard 5282.6943293 5 255.2813460 Preprocessor1_Model064
64 rsq standard 0.8409379 5 0.0173612 Preprocessor1_Model064
65 rmse standard 5310.0626579 5 255.2820449 Preprocessor1_Model065
65 rsq standard 0.8403366 5 0.0176917 Preprocessor1_Model065
66 rmse standard 5330.5619104 5 255.2167332 Preprocessor1_Model066
66 rsq standard 0.8401139 5 0.0178185 Preprocessor1_Model066
67 rmse standard 5345.0835302 5 254.7457515 Preprocessor1_Model067
67 rsq standard 0.8401665 5 0.0179959 Preprocessor1_Model067
68 rmse standard 5359.6814118 5 250.1877437 Preprocessor1_Model068
68 rsq standard 0.8405975 5 0.0177160 Preprocessor1_Model068
69 rmse standard 5390.7338072 5 248.9819478 Preprocessor1_Model069
69 rsq standard 0.8402191 5 0.0177109 Preprocessor1_Model069
70 rmse standard 5406.9539519 5 247.6082823 Preprocessor1_Model070
70 rsq standard 0.8405627 5 0.0176769 Preprocessor1_Model070
71 rmse standard 5424.0772491 5 249.3965793 Preprocessor1_Model071
71 rsq standard 0.8404457 5 0.0177430 Preprocessor1_Model071
72 rmse standard 5453.3601870 5 245.8968002 Preprocessor1_Model072
72 rsq standard 0.8397146 5 0.0176040 Preprocessor1_Model072
73 rmse standard 5474.6071400 5 243.8752587 Preprocessor1_Model073
73 rsq standard 0.8394817 5 0.0176707 Preprocessor1_Model073
74 rmse standard 5500.3324813 5 238.5758351 Preprocessor1_Model074
74 rsq standard 0.8389886 5 0.0176102 Preprocessor1_Model074
75 rmse standard 5533.2174181 5 235.5433973 Preprocessor1_Model075
75 rsq standard 0.8385261 5 0.0176846 Preprocessor1_Model075
76 rmse standard 5560.6640148 5 235.6440023 Preprocessor1_Model076
76 rsq standard 0.8381477 5 0.0179443 Preprocessor1_Model076
77 rmse standard 5584.7835166 5 236.0315921 Preprocessor1_Model077
77 rsq standard 0.8377637 5 0.0181491 Preprocessor1_Model077
78 rmse standard 5604.5919775 5 236.6764084 Preprocessor1_Model078
78 rsq standard 0.8377688 5 0.0182580 Preprocessor1_Model078
79 rmse standard 5633.3447960 5 235.5720509 Preprocessor1_Model079
79 rsq standard 0.8372383 5 0.0183068 Preprocessor1_Model079
80 rmse standard 5649.1854764 5 233.0229942 Preprocessor1_Model080
80 rsq standard 0.8376092 5 0.0183126 Preprocessor1_Model080
81 rmse standard 5669.7011509 5 227.1266987 Preprocessor1_Model081
81 rsq standard 0.8380915 5 0.0180356 Preprocessor1_Model081
82 rmse standard 5692.7484482 5 218.4712738 Preprocessor1_Model082
82 rsq standard 0.8383354 5 0.0177220 Preprocessor1_Model082
83 rmse standard 5715.0654889 5 217.8337115 Preprocessor1_Model083
83 rsq standard 0.8382711 5 0.0177845 Preprocessor1_Model083
84 rmse standard 5737.5170256 5 217.0799424 Preprocessor1_Model084
84 rsq standard 0.8381157 5 0.0177577 Preprocessor1_Model084
85 rmse standard 5765.2142956 5 215.0876605 Preprocessor1_Model085
85 rsq standard 0.8381152 5 0.0176802 Preprocessor1_Model085
86 rmse standard 5790.4235698 5 210.4478847 Preprocessor1_Model086
86 rsq standard 0.8379283 5 0.0174265 Preprocessor1_Model086
87 rmse standard 5820.4823630 5 207.4782758 Preprocessor1_Model087
87 rsq standard 0.8379163 5 0.0172870 Preprocessor1_Model087
88 rmse standard 5846.0684056 5 204.0553549 Preprocessor1_Model088
88 rsq standard 0.8375326 5 0.0172937 Preprocessor1_Model088
89 rmse standard 5872.1645360 5 198.7234937 Preprocessor1_Model089
89 rsq standard 0.8370917 5 0.0170569 Preprocessor1_Model089
90 rmse standard 5893.2743975 5 196.8409955 Preprocessor1_Model090
90 rsq standard 0.8370979 5 0.0169406 Preprocessor1_Model090
91 rmse standard 5921.1560325 5 195.0947100 Preprocessor1_Model091
91 rsq standard 0.8372950 5 0.0169937 Preprocessor1_Model091
92 rmse standard 5949.2206657 5 188.9127415 Preprocessor1_Model092
92 rsq standard 0.8372605 5 0.0168327 Preprocessor1_Model092
93 rmse standard 5974.8455124 5 185.1585418 Preprocessor1_Model093
93 rsq standard 0.8372469 5 0.0169116 Preprocessor1_Model093
94 rmse standard 6002.5265350 5 183.7416926 Preprocessor1_Model094
94 rsq standard 0.8369820 5 0.0168052 Preprocessor1_Model094
95 rmse standard 6033.7125940 5 182.7454155 Preprocessor1_Model095
95 rsq standard 0.8364537 5 0.0168431 Preprocessor1_Model095
96 rmse standard 6058.9580811 5 181.8945963 Preprocessor1_Model096
96 rsq standard 0.8368205 5 0.0168068 Preprocessor1_Model096
97 rmse standard 6089.5336705 5 181.2305913 Preprocessor1_Model097
97 rsq standard 0.8366151 5 0.0167848 Preprocessor1_Model097
98 rmse standard 6118.9537048 5 178.0149517 Preprocessor1_Model098
98 rsq standard 0.8367219 5 0.0166843 Preprocessor1_Model098
99 rmse standard 6148.5943339 5 176.6887710 Preprocessor1_Model099
99 rsq standard 0.8367294 5 0.0166779 Preprocessor1_Model099
100 rmse standard 6181.5358696 5 171.9513420 Preprocessor1_Model100
100 rsq standard 0.8367538 5 0.0165249 Preprocessor1_Model100

En comparant le RMSE pour les différents k,on peut constater que si on prend trop ou très peu de voisins, nous obtenons une faible précision.

ins_results %>%
  filter(.metric == "rmse") %>%
  ggplot(aes(neighbors, mean)) +
  geom_line(col="red") +
  geom_point(col = "#336980")+
  labs(title="Residuals vs number of neighbours")

Alors qu’on trouve que la valeur optimale est de considérer 19 neighbours, chose qui peut être interprété comme suit : pour prédire les charges d’un assuré, il vaut mieux voir les 19 plus proches observations des assurés pour les variables age + bmi + children + health + region.

ins_min <- ins_results %>%
  filter(.metric == "rmse") %>%
  filter(mean == min(mean))
ins_min %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
neighbors .metric .estimator mean n std_err .config
21 rmse standard 4629.166 5 260.0895 Preprocessor1_Model021
kmin <- ins_min %>% pull(neighbors)

maintenant on va re-entrainer notre modèle avec la valeur optimale de k=21, et évaluer sa performance.

set.seed(123456)
ins_spec <-
  nearest_neighbor(weight_func = "rectangular", neighbors = kmin) %>%
  set_engine("kknn") %>%
  set_mode("regression")

ins_fit <- workflow() %>%
  add_recipe(ins_recipe) %>%
  add_model(ins_spec) %>%
  fit(data = insurance_train)
ins_fit
## == Workflow [trained] ==========================================================
## Preprocessor: Recipe
## Model: nearest_neighbor()
## 
## -- Preprocessor ----------------------------------------------------------------
## 2 Recipe Steps
## 
## * step_scale()
## * step_center()
## 
## -- Model -----------------------------------------------------------------------
## 
## Call:
## kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(21L,     data, 5), kernel = ~"rectangular")
## 
## Type of response variable: continuous
## minimal mean absolute error: 2614.916
## Minimal mean squared error: 21519995
## Best kernel: rectangular
## Best k: 21
ins_res <- ins_fit %>%
  predict(insurance_test) %>%
  bind_cols(insurance_test)

ins_res %>%  ggplot(aes(y = charges, x = .pred, col = health)) +
  geom_point() +
  scale_color_manual(values = c("#79a346", "#245373","#964a74")) +
  geom_abline(slope = 1)

ins_summary <- ins_res %>%
  metrics(truth = charges, estimate = .pred)

ins_summary %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
.metric .estimator .estimate
rmse standard 4642.8788742
rsq standard 0.8698845
mae standard 2675.5606994

Random forest

Modèl initial

On va commencer par entrainer un modèle de “random forest” sur le training-set en utilisant l’ensemble des variables et en gardant les paramètres par défaut qu’on va évaluer sur le test-set.

library(randomForest)
model_rand_forest <- rand_forest(mode = "regression") %>%
  set_engine("randomForest") %>%
  fit(charges ~ age + bmi + children + health +
        region + sex, data = insurance_train)
model_rand_forest_pred <- model_rand_forest %>%
  predict(new_data = insurance_test) %>%
  bind_cols(insurance_test %>% dplyr::select(charges))
model_rand_forest_pred %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")  %>% 
  # kableExtra::kable_paper() %>%
  kableExtra::scroll_box(width = "230px", height = "300px")
.pred charges
10451.433 6406.411
3660.008 2198.190
6335.282 4687.797
2650.023 1625.434
5255.351 3046.062
5914.029 4949.759
7127.264 6313.759
8745.097 6079.672
24201.908 23568.272
37416.161 37742.576
45024.699 47496.494
8077.114 5989.524
3122.390 1743.214
7023.986 5920.104
19461.594 16577.780
3258.499 1532.470
22657.791 21098.554
12131.086 8026.667
17818.219 15820.699
6291.896 5003.853
6277.419 4646.759
12104.121 11488.317
2753.991 1705.624
4696.097 3385.399
16312.455 32734.186
8002.223 6082.405
13703.279 12815.445
3901.848 2457.211
3712.193 1842.519
21616.845 19964.746
10121.862 6948.701
6680.250 5152.134
12675.568 10407.086
11412.459 8116.680
6258.789 4005.423
9949.955 7419.478
39358.065 43753.337
5224.614 4883.866
3559.738 1639.563
3799.573 2130.676
36462.321 37133.898
11101.341 7147.105
5816.038 1980.070
9060.899 8520.026
8562.868 7371.772
5973.723 2483.736
5792.565 5253.524
23597.827 19515.542
5301.054 2689.495
11767.640 24227.337
7042.353 6710.192
18616.737 19444.266
9463.841 7152.671
2796.933 1832.094
41704.425 41097.162
14748.181 13047.332
33955.399 33750.292
13813.137 20462.998
42807.802 46151.124
14392.991 14590.632
11574.902 9282.481
11499.566 9617.662
13989.587 12928.791
44759.885 48549.178
6240.538 4237.127
11742.170 9625.920
13794.900 9432.925
5927.263 3172.018
38474.315 38746.355
10729.208 9249.495
5917.838 20177.671
6158.480 4151.029
11956.411 8444.474
10554.241 8835.265
11062.991 7421.195
6078.886 4894.753
43443.660 47928.030
16418.889 13937.666
16314.367 13217.094
14569.062 13981.850
5953.594 3554.203
2906.175 14133.038
11649.409 10043.249
7068.275 3180.510
8246.913 3481.868
15812.307 16455.708
42067.931 42303.692
6874.941 5846.918
11622.378 8302.536
4344.020 1261.859
9964.540 9264.797
21168.982 19594.810
5642.873 2727.395
10590.816 8968.330
10943.152 9788.866
6937.360 18804.752
7527.599 5969.723
4516.087 2254.797
6912.075 5926.846
36192.333 37079.372
3844.132 1149.396
16017.208 12731.000
13062.726 11454.022
4284.787 2497.038
11150.924 9563.029
5604.969 4347.023
12271.593 12475.351
42997.411 48885.136
11686.132 11455.280
3781.670 27724.289
10989.137 8413.463
12825.503 9861.025
6787.646 5972.378
7088.685 6196.448
14239.052 13887.204
14338.218 10231.500
6978.112 3268.847
45184.647 45863.205
6187.720 3935.180
3648.610 1646.430
10118.581 9058.730
5073.998 2128.431
8209.503 7256.723
3308.223 1665.000
6452.114 3861.210
34883.391 43943.876
14354.402 13635.638
8397.619 7640.309
8291.151 5594.846
2957.054 1633.044
12082.887 10713.644
11204.735 9182.170
6266.344 11326.715
10128.842 9391.346
14202.350 14410.932
3844.630 2709.112
20321.587 20149.323
22536.564 32787.459
4402.245 1712.227
12567.328 12430.953
22366.704 24667.419
5513.127 3410.324
13856.903 12363.547
11511.613 10156.783
12414.191 29186.482
39494.548 40273.645
12416.558 10976.246
11902.851 9541.696
6760.476 5375.038
8204.819 6113.231
5890.396 5469.007
43028.597 43254.418
19710.618 19539.243
4854.934 2416.955
13832.884 14319.031
7137.109 6933.242
12375.762 9869.810
23249.079 24520.264
11897.721 11848.141
7031.613 28476.735
9815.333 9414.920
5687.986 2842.761
36363.311 55135.402
8414.867 7445.918
3901.651 1621.883
6525.474 4719.737
5189.801 1526.312
13131.374 12323.936
35773.630 36021.011
4439.393 2974.126
12839.126 10601.632
3055.762 1141.445
10602.393 8457.818
5161.794 3392.365
12671.686 26140.360
4676.341 2789.057
8156.097 4877.981
19094.409 19719.695
6167.933 5272.176
13393.098 13063.883
6142.009 4661.286
13871.941 14382.709
5736.741 2473.334
6369.748 5245.227
13169.838 13462.520
12201.263 25333.333
6609.788 2913.569
7975.012 6289.755
12478.543 10096.970
9088.949 7348.142
14399.332 9487.644
38840.379 39047.285
37484.189 38998.546
4215.930 12609.887
10440.395 9500.573
18650.951 19199.944
7654.150 4915.060
4152.923 3378.910
6302.337 5484.467
16418.292 16420.495
8125.900 7986.475
10332.652 8627.541
7084.050 4438.263
6529.118 3987.926
14316.166 12495.291
3889.998 1711.027
3338.757 2020.177
44657.913 44423.803
15168.442 13747.872
6549.855 22493.660
40960.798 39727.614
9363.662 17929.303
4585.558 2480.979
46105.477 48970.248
10449.176 8978.185
13405.882 13204.286
15911.262 15161.534
12464.283 11353.228
12260.997 9748.911
5869.858 20420.605
11022.017 8988.159
13679.689 10493.946
39453.551 38282.749
34293.996 34166.273
16221.916 14254.608
11720.541 9991.038
12215.450 11085.587
9545.484 7623.518
5902.753 2203.736
40335.488 40941.285
5958.347 3989.841
9031.146 7727.253
12087.194 10982.501
5837.904 2899.489
19191.890 19350.369
10388.267 7650.774
5535.421 2850.684
6377.622 2632.992
11864.433 9447.382
13020.834 36910.608
37660.758 38415.474
19573.251 20296.863
4350.166 12890.058
39993.455 41661.602
7401.246 7537.164
8968.118 7162.012
7544.543 6985.507
46090.708 47269.854
2862.347 1633.962
38734.561 37607.528
15476.619 30063.581
8276.101 7337.748
34502.953 34254.053
4922.545 3292.530
11391.260 10959.330
22680.237 24535.699
44181.887 47403.880
12158.897 8534.672
12102.118 11931.125
43595.950 46718.163
7686.517 2464.619
7341.964 7201.701
8920.595 22395.744
12315.288 10325.206
modelPerform_rand_forest = data.frame(
  RMSE = RMSE(
    model_rand_forest_pred$.pred,
    model_rand_forest_pred$charges
  ),
  R2 = R2(
    model_rand_forest_pred$.pred,
    model_rand_forest_pred$charges
  )
)
modelPerform_rand_forest %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
RMSE R2
4720.893 0.8689894

Ce modèle a réussi à expliquer 86,81% de la variance totale avec un RMSE de 4726.053.

Amélioration du modèle

On va essayer maintenant d’améliorer ces performances en réglant les hyperparamètres min_n(nombre minimale d’observations requis dans un noeud pour le diviser) et mtry(Le nombre de prédicteurs qui seront échantillonnés au hasard à chaque division lors de la création des modèles d’arbre).

Pour automatiser les prétraitements, on va utiliser une recipe pour centrer et réduire les variables numeriques age et bmi, et la combiner avec la spécification du modèle dans un workflow.

#creating recipe
ins_recipe2 <-
  recipe(charges ~ age + bmi + children + health + region + sex, data = insurance_train) %>%
  step_scale(age, bmi) %>%
  step_center(age, bmi)
#model specification
ins_spec_rf <- rand_forest(min_n = tune(),mtry=tune()) %>%
  set_engine("randomForest") %>%
  set_mode("regression")
#workflow(model+recipe)
ins_wf_rf <- workflow() %>%
  add_recipe(ins_recipe2) %>%
  add_model(ins_spec_rf)
ins_wf_rf
## == Workflow ====================================================================
## Preprocessor: Recipe
## Model: rand_forest()
## 
## -- Preprocessor ----------------------------------------------------------------
## 2 Recipe Steps
## 
## * step_scale()
## * step_center()
## 
## -- Model -----------------------------------------------------------------------
## Random Forest Model Specification (regression)
## 
## Main Arguments:
##   mtry = tune()
##   min_n = tune()
## 
## Computational engine: randomForest

Tuning des hyperparamètres :
On va utiliser la méthode “Grid search” Pour pouvoir trouver la combinaison optimale des hyperparamètres “mtry” et “min_n”.

set.seed(123456)

ins_results_rf_grid <- ins_wf_rf %>%
  tune_grid(resamples = ins_vfold, grid = data_grid) 

ins_results_rf <- ins_results_rf_grid%>%
  collect_metrics()

ins_results_rf %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = T, position = "left")  %>% 
  kableExtra::kable_paper() %>%
  kableExtra::scroll_box( height = "300px")
mtry min_n .metric .estimator mean n std_err .config
3 26 rmse standard 4502.9713269 5 246.9508460 Preprocessor1_Model01
3 26 rsq standard 0.8572369 5 0.0178017 Preprocessor1_Model01
4 26 rmse standard 4486.3605136 5 247.3486272 Preprocessor1_Model02
4 26 rsq standard 0.8575817 5 0.0176938 Preprocessor1_Model02
5 26 rmse standard 4507.6436824 5 241.7663073 Preprocessor1_Model03
5 26 rsq standard 0.8565050 5 0.0174662 Preprocessor1_Model03
6 26 rmse standard 4540.8392467 5 238.8689672 Preprocessor1_Model04
6 26 rsq standard 0.8547268 5 0.0173503 Preprocessor1_Model04
3 28 rmse standard 4492.4823224 5 246.0261177 Preprocessor1_Model05
3 28 rsq standard 0.8579409 5 0.0174977 Preprocessor1_Model05
4 28 rmse standard 4480.4196667 5 248.8713755 Preprocessor1_Model06
4 28 rsq standard 0.8579075 5 0.0178142 Preprocessor1_Model06
5 28 rmse standard 4506.2273528 5 242.0233731 Preprocessor1_Model07
5 28 rsq standard 0.8565265 5 0.0174745 Preprocessor1_Model07
6 28 rmse standard 4531.7370785 5 240.9903022 Preprocessor1_Model08
6 28 rsq standard 0.8551576 5 0.0174482 Preprocessor1_Model08
3 30 rmse standard 4502.4015236 5 253.0207005 Preprocessor1_Model09
3 30 rsq standard 0.8572583 5 0.0180438 Preprocessor1_Model09
4 30 rmse standard 4480.9065073 5 252.5024903 Preprocessor1_Model10
4 30 rsq standard 0.8577283 5 0.0179642 Preprocessor1_Model10
5 30 rmse standard 4504.6791618 5 247.8567109 Preprocessor1_Model11
5 30 rsq standard 0.8565087 5 0.0178983 Preprocessor1_Model11
6 30 rmse standard 4531.1114752 5 241.6965059 Preprocessor1_Model12
6 30 rsq standard 0.8551242 5 0.0174564 Preprocessor1_Model12
3 32 rmse standard 4499.8684604 5 249.1454239 Preprocessor1_Model13
3 32 rsq standard 0.8574502 5 0.0178746 Preprocessor1_Model13
4 32 rmse standard 4485.3582706 5 254.1382742 Preprocessor1_Model14
4 32 rsq standard 0.8573791 5 0.0181175 Preprocessor1_Model14
5 32 rmse standard 4499.0556140 5 249.3251531 Preprocessor1_Model15
5 32 rsq standard 0.8567260 5 0.0178276 Preprocessor1_Model15
6 32 rmse standard 4517.2728936 5 246.0942965 Preprocessor1_Model16
6 32 rsq standard 0.8558118 5 0.0176314 Preprocessor1_Model16
3 34 rmse standard 4494.1781361 5 247.5034820 Preprocessor1_Model17
3 34 rsq standard 0.8579955 5 0.0177320 Preprocessor1_Model17
4 34 rmse standard 4477.2679787 5 252.1122315 Preprocessor1_Model18
4 34 rsq standard 0.8578414 5 0.0179604 Preprocessor1_Model18
5 34 rmse standard 4498.0367395 5 255.1573410 Preprocessor1_Model19
5 34 rsq standard 0.8565994 5 0.0183137 Preprocessor1_Model19
6 34 rmse standard 4519.5092097 5 249.7806274 Preprocessor1_Model20
6 34 rsq standard 0.8555612 5 0.0179683 Preprocessor1_Model20
3 36 rmse standard 4498.9937117 5 253.1112587 Preprocessor1_Model21
3 36 rsq standard 0.8575007 5 0.0180595 Preprocessor1_Model21
4 36 rmse standard 4482.6950210 5 257.3142287 Preprocessor1_Model22
4 36 rsq standard 0.8575048 5 0.0183140 Preprocessor1_Model22
5 36 rmse standard 4493.8055794 5 253.7648297 Preprocessor1_Model23
5 36 rsq standard 0.8568971 5 0.0181862 Preprocessor1_Model23
6 36 rmse standard 4519.1213683 5 250.4609414 Preprocessor1_Model24
6 36 rsq standard 0.8555345 5 0.0180142 Preprocessor1_Model24
3 38 rmse standard 4518.8059165 5 249.9243199 Preprocessor1_Model25
3 38 rsq standard 0.8566750 5 0.0179041 Preprocessor1_Model25
4 38 rmse standard 4486.1647083 5 257.0999480 Preprocessor1_Model26
4 38 rsq standard 0.8571813 5 0.0182885 Preprocessor1_Model26
5 38 rmse standard 4498.2283354 5 252.1359709 Preprocessor1_Model27
5 38 rsq standard 0.8564944 5 0.0180877 Preprocessor1_Model27
6 38 rmse standard 4517.3020586 5 253.0717244 Preprocessor1_Model28
6 38 rsq standard 0.8554945 5 0.0182853 Preprocessor1_Model28

Visualisation de l’erreur pour les combinaisons de “mtry” et “min_n” :

ins_results_rf %>%
  filter(.metric == "rmse") %>%
  mutate(min_n = factor(min_n)) %>%
  ggplot(aes(x=mtry, y=mean,col=min_n)) +
  geom_line(alpha = 0.5, size = 1.5) +
  geom_point() +
  labs(title="RMSE for each combinaison of mtry & min_n",y = "rmse")+
  theme(plot.title = element_text(hjust = 0.5))

finalement, on va choisir les meilleurs hyper-paramètres pour notre modèle, puis mettre à jour la spécification du modèle d’origine pour créer la spécification de modèle finale.

best_param <- ins_results_rf %>%
  filter(.metric == "rmse") %>%
  filter(mean == min(mean))
best_param %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
mtry min_n .metric .estimator mean n std_err .config
4 34 rmse standard 4477.268 5 252.1122 Preprocessor1_Model18
best_mtry <- best_param %>% pull(mtry)
best_min_n <- best_param %>% pull(min_n)
best_rmse <- select_best(ins_results_rf_grid,"rmse")
final_rf <- finalize_model(
  ins_spec_rf,
  best_rmse
)
final_rf
## Random Forest Model Specification (regression)
## 
## Main Arguments:
##   mtry = 4
##   min_n = 34
## 
## Computational engine: randomForest

Les hyper-paramètres optimaux sont mtry = 4 et min_n = 34 .

last_fit(final_rf, charges~age + bmi + children + health + region + sex, insurance_data_split) %>% 
  collect_metrics() %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
.metric .estimator .estimate .config
rmse standard 4596.2229943 Preprocessor1_Model1
rsq standard 0.8724205 Preprocessor1_Model1

L’optimisation des hyperparamètres a amélioré les preformance,

model_final_rf <- rand_forest(mode = "regression",mtry=best_mtry,min_n = best_min_n) %>%
  set_engine("randomForest") %>%
  fit(charges ~ age + bmi + children + health +region + sex, data = insurance_train)
model_final_rf_pred <- model_final_rf %>%
  predict(new_data = insurance_test) %>%
  bind_cols(insurance_test %>% dplyr::select(charges))
model_final_rf_pred %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")  %>% 
  # kableExtra::kable_paper() %>%
  kableExtra::scroll_box(width = "230px", height = "300px")
.pred charges
9313.516 6406.411
2825.881 2198.190
6488.111 4687.797
2035.577 1625.434
5258.149 3046.062
5496.489 4949.759
7100.562 6313.759
8053.543 6079.672
25660.807 23568.272
38533.203 37742.576
46657.647 47496.494
7361.800 5989.524
2720.721 1743.214
7055.363 5920.104
18660.887 16577.780
2589.512 1532.470
23737.523 21098.554
11330.600 8026.667
17857.613 15820.699
5998.075 5003.853
6289.880 4646.759
13072.438 11488.317
2343.480 1705.624
4775.943 3385.399
16957.235 32734.186
7062.731 6082.405
13898.893 12815.445
3267.267 2457.211
3034.871 1842.519
23232.219 19964.746
9521.560 6948.701
6228.107 5152.134
12215.901 10407.086
10353.571 8116.680
6263.941 4005.423
9750.669 7419.478
41995.343 43753.337
5326.750 4883.866
2948.682 1639.563
3002.279 2130.676
37745.972 37133.898
10158.129 7147.105
3892.773 1980.070
8940.956 8520.026
7633.471 7371.772
4898.152 2483.736
5305.123 5253.524
23834.665 19515.542
4761.666 2689.495
12532.750 24227.337
7516.443 6710.192
18372.217 19444.266
8560.121 7152.671
2633.732 1832.094
43264.498 41097.162
14419.087 13047.332
35894.156 33750.292
13747.911 20462.998
44097.105 46151.124
14596.187 14590.632
11087.346 9282.481
12631.051 9617.662
14160.961 12928.791
46682.437 48549.178
6449.327 4237.127
12739.273 9625.920
12256.187 9432.925
5383.750 3172.018
39760.661 38746.355
11286.481 9249.495
5172.913 20177.671
6106.909 4151.029
11648.594 8444.474
10782.399 8835.265
10904.313 7421.195
5835.489 4894.753
46091.216 47928.030
17849.756 13937.666
15447.220 13217.094
15185.567 13981.850
4889.230 3554.203
2385.165 14133.038
11791.167 10043.249
5976.146 3180.510
7568.008 3481.868
16560.577 16455.708
45292.592 42303.692
6790.604 5846.918
10633.440 8302.536
2933.802 1261.859
10725.511 9264.797
22349.628 19594.810
4968.008 2727.395
10294.723 8968.330
11107.695 9788.866
5842.575 18804.752
6509.857 5969.723
3816.306 2254.797
6324.843 5926.846
37967.233 37079.372
2151.573 1149.396
14487.568 12731.000
12908.331 11454.022
3870.831 2497.038
10479.305 9563.029
5726.918 4347.023
13383.609 12475.351
43341.317 48885.136
11880.460 11455.280
2534.345 27724.289
9994.734 8413.463
11822.284 9861.025
6356.439 5972.378
6746.965 6196.448
14780.395 13887.204
12753.393 10231.500
6456.774 3268.847
45568.914 45863.205
5887.979 3935.180
2330.390 1646.430
10068.375 9058.730
3264.518 2128.431
7810.201 7256.723
2340.970 1665.000
6346.507 3861.210
39966.216 43943.876
14493.389 13635.638
7954.188 7640.309
8104.312 5594.846
2329.208 1633.044
12894.770 10713.644
11145.111 9182.170
5935.874 11326.715
10164.694 9391.346
14931.465 14410.932
3039.612 2709.112
20035.035 20149.323
24892.767 32787.459
3092.736 1712.227
12485.291 12430.953
23750.248 24667.419
4854.198 3410.324
13271.241 12363.547
11348.461 10156.783
12366.547 29186.482
40995.523 40273.645
12066.790 10976.246
12169.318 9541.696
6388.622 5375.038
7449.036 6113.231
5754.314 5469.007
45830.372 43254.418
19481.902 19539.243
3910.800 2416.955
14284.351 14319.031
7834.577 6933.242
12127.417 9869.810
25600.279 24520.264
12104.214 11848.141
6936.738 28476.735
9464.778 9414.920
5453.816 2842.761
39116.387 55135.402
8417.074 7445.918
3605.787 1621.883
6036.647 4719.737
4322.806 1526.312
13924.063 12323.936
36866.129 36021.011
4381.692 2974.126
12360.591 10601.632
1920.126 1141.445
10344.478 8457.818
4675.338 3392.365
11313.690 26140.360
4349.087 2789.057
8661.721 4877.981
20231.658 19719.695
5488.869 5272.176
14164.942 13063.883
5870.150 4661.286
14895.313 14382.709
5337.819 2473.334
6009.825 5245.227
13913.109 13462.520
11816.011 25333.333
6098.964 2913.569
7182.002 6289.755
11255.735 10096.970
9151.591 7348.142
13473.497 9487.644
39884.879 39047.285
45186.663 38998.546
4094.535 12609.887
10409.994 9500.573
18814.759 19199.944
7492.538 4915.060
4670.178 3378.910
6202.462 5484.467
16693.063 16420.495
7998.246 7986.475
10465.608 8627.541
7905.082 4438.263
5186.196 3987.926
13589.414 12495.291
2744.719 1711.027
3294.056 2020.177
46280.671 44423.803
16365.670 13747.872
7055.858 22493.660
42875.918 39727.614
8743.255 17929.303
4749.798 2480.979
46856.250 48970.248
10390.913 8978.185
14084.319 13204.286
16488.154 15161.534
12500.187 11353.228
13621.118 9748.911
6285.489 20420.605
10942.061 8988.159
13411.842 10493.946
40974.398 38282.749
35832.055 34166.273
15699.971 14254.608
12546.335 9991.038
12556.381 11085.587
9282.351 7623.518
4376.375 2203.736
41636.720 40941.285
5638.113 3989.841
9254.732 7727.253
12632.213 10982.501
5414.342 2899.489
20258.507 19350.369
8847.998 7650.774
5669.170 2850.684
5153.549 2632.992
12303.216 9447.382
13712.104 36910.608
39138.877 38415.474
19613.455 20296.863
3115.142 12890.058
42625.993 41661.602
6959.332 7537.164
8082.499 7162.012
7950.227 6985.507
46357.351 47269.854
2329.012 1633.962
39895.039 37607.528
14673.729 30063.581
8582.364 7337.748
35842.905 34254.053
4575.516 3292.530
12170.936 10959.330
25046.491 24535.699
46576.316 47403.880
12017.221 8534.672
13006.688 11931.125
45957.826 46718.163
6179.391 2464.619
7514.225 7201.701
8435.519 22395.744
13066.856 10325.206
modelPerform_model_final = data.frame(
  RMSE = RMSE(
    model_final_rf_pred$.pred,
    model_final_rf_pred$charges
  ),
  R2 = R2(
    model_final_rf_pred$.pred,
    model_final_rf_pred$charges
  )
)
modelPerform_model_final %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
RMSE R2
4607.556 0.8717825
model_final_rf_pred %>%  ggplot(aes(y = charges, x = .pred)) +
  geom_point(col="#336980",size=2,alpha=0.5) +
  geom_abline(slope = 1,col="#821f2e")

Comparaison des modèles

result_linear_mod = predict(mod.extend, insurance_data)
lm_rmse = RMSE(result_linear_mod, insurance_data$charges)
lm_rsq = perf %>% pull(r.squared)
knn_rmse = ins_summary %>% filter(.metric == "rmse") %>% pull(.estimate)
knn_rsq = ins_summary %>% filter(.metric == "rsq") %>% pull(.estimate)
rf_rsq = modelPerform_model_final %>% pull(R2)
rf_rmse = modelPerform_model_final %>% pull(RMSE)

model <- c("Linear model", "KNN", "Random forest")
R_squared <- c(lm_rsq, knn_rsq, rf_rsq)
RMSE <- c(lm_rmse, knn_rmse, rf_rmse)

data.frame(model, R_squared, RMSE) %>% 
  knitr::kable(align = "c") %>% 
  kableExtra::kable_styling(full_width = F, position = "left")
model R_squared RMSE
Linear model 0.8721790 4327.960
KNN 0.8698845 4642.879
Random forest 0.8717825 4607.556

Conclusion

Les trois modèles permettent d’atteindre presque la même précision, mais on a un grand avantage de l’interprétabilité dans le modèle linéaire.
Pour obtenir encore plus de précision dans ses prédiction, on propose à cette assurance de collecter plus de données sur ses clients afin d’expliquer le comportement de certain individus qu’on a remarquer -dans la phase de visualisation- qu’ils ne suivent pas la même tendance des autres.
Si, même après l’étude, on n’a pas arriver à distinguer ces observations du reste de la population on doit pensée à utiliser des modéles dédiés à ce problème, principalement RLM Robust Linear model qui donne un poid faibles aux points influenceurs.