kilometer’s

a junk space

[研究関連]  [R関連] [書評] [その他]

home


目標

準備

# Enter the project's pipenv environment, then add the core packages.
pipenv shell
pipenv install pandas numpy umap-learn

umap.plotを使いたい場合は、追加で下記もインストールする。

# Additional packages installed for umap.plot.
pipenv install scikit-learn datashader bokeh holoviews \
  matplotlib scikit-image colorcet
library(reticulate)
# Python interpreter that reticulate must use
# (NOTE: "*" looks like a placeholder for the environment directory).
venv <- "*/bin/python"
reticulate::use_python(venv, required = TRUE)
import pandas as pd
import numpy as np
import umap
import umap.plot
library(tidyverse)
library(uwot)

データ

# Palmer penguins data with every row containing a missing value removed.
dat <- na.omit(palmerpenguins::penguins)
dat
## # A tibble: 333 x 8
##    species island bill_length_mm bill_depth_mm flipper_length_… body_mass_g
##    <fct>   <fct>           <dbl>         <dbl>            <int>       <int>
##  1 Adelie  Torge…           39.1          18.7              181        3750
##  2 Adelie  Torge…           39.5          17.4              186        3800
##  3 Adelie  Torge…           40.3          18                195        3250
##  4 Adelie  Torge…           36.7          19.3              193        3450
##  5 Adelie  Torge…           39.3          20.6              190        3650
##  6 Adelie  Torge…           38.9          17.8              181        3625
##  7 Adelie  Torge…           39.2          19.6              195        4675
##  8 Adelie  Torge…           41.1          17.6              182        3200
##  9 Adelie  Torge…           38.6          21.2              191        3800
## 10 Adelie  Torge…           34.6          21.1              198        4400
## # … with 323 more rows, and 2 more variables: sex <fct>, year <int>

データ加工

# Build the UMAP input: keep only the measurement columns (names ending in
# "_mm" or "_g") and standardise each one to mean 0 and standard deviation 1.
# `across()` replaces the superseded `mutate_all()`.
dat_for_umap <-
  dat %>%
  select(ends_with("_mm"), ends_with("_g")) %>%
  mutate(across(everything(), ~ (.x - mean(.x)) / sd(.x)))
dat_for_umap
## # A tibble: 333 x 4
##    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
##             <dbl>         <dbl>             <dbl>       <dbl>
##  1         -0.895         0.780            -1.42       -0.568
##  2         -0.822         0.119            -1.07       -0.506
##  3         -0.675         0.424            -0.426      -1.19 
##  4         -1.33          1.08             -0.568      -0.940
##  5         -0.858         1.74             -0.782      -0.692
##  6         -0.931         0.323            -1.42       -0.723
##  7         -0.876         1.24             -0.426       0.581
##  8         -0.529         0.221            -1.35       -1.25 
##  9         -0.986         2.05             -0.711      -0.506
## 10         -1.72          2.00             -0.212       0.240
## # … with 323 more rows

R側でumapを実行

# Run UMAP on the R side; the arguments are simply left at their initial
# settings for now.
umap_mat_r <- umap(dat_for_umap,
                   n_neighbors = 15,
                   n_components = 2)

# Name the columns explicitly before converting: as_tibble() on a matrix
# without column names only works through a deprecated compatibility path
# (tibble >= 2.0.0). V1, V2, ... are the names the plots below expect.
colnames(umap_mat_r) <- str_c("V", seq_len(ncol(umap_mat_r)))

result_umap_r <- as_tibble(umap_mat_r)
result_umap_r
## # A tibble: 333 x 2
##       V1     V2
##    <dbl>  <dbl>
##  1 -7.21 -0.110
##  2 -8.04  0.546
##  3 -8.02  1.29 
##  4 -5.88  0.156
##  5 -5.42 -0.836
##  6 -7.66  0.337
##  7 -5.75 -2.07 
##  8 -8.27  0.508
##  9 -5.33 -0.888
## 10 -5.36 -0.925
## # … with 323 more rows
# Attach the UMAP coordinates to the penguin data and draw the embedding,
# coloured by species.
dat_g <- dat %>% bind_cols(result_umap_r)
ggplot(dat_g, aes(x = V1, y = V2, color = species)) +
  geom_point(alpha = 0.5, size = 0.75)

はい、きれい。

Pythonでumapを実行

# Fetch the standardised data prepared in R (dat_for_umap) through the `r`
# object that reticulate exposes to Python.
dat = r.dat_for_umap
# Bare expression so that the chunk displays the data.
dat
##      bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g
## 0         -0.894695       0.779559          -1.424608    -0.567621
## 1         -0.821552       0.119404          -1.067867    -0.505525
## 2         -0.675264       0.424091          -0.425733    -1.188572
## 3         -1.333559       1.084246          -0.568429    -0.940192
## 4         -0.858123       1.744400          -0.782474    -0.691811
## ..              ...            ...                ...          ...
## 328        2.159064       1.338151           0.430446    -0.257145
## 329       -0.090112       0.474872           0.073705    -1.002287
## 330        1.025333       0.525653          -0.568429    -0.536573
## 331        1.244765       0.931902           0.644491    -0.132954
## 332        1.135049       0.779559          -0.211688    -0.536573
## 
## [333 rows x 4 columns]
# Create a UMAP model with its default arguments and fit it to the data.
model_umap = umap.UMAP()
# The value of this call is what the chunk displays (the model's repr).
model_umap.fit(dat)
## UMAP(tqdm_kwds={'bar_format': '{desc}: {percentage:3.0f}%| {bar} {n_fmt}/{total_fmt} [{elapsed}]', 'desc': 'Epochs completed', 'disable': True})
# Draw the embedding of the fitted model with umap.plot.
umap.plot.points(model_umap)

いい感じ。

# Reuse the embedding of the model fitted above. Calling fit_transform() here
# would refit the model; since no random_state is set, the new coordinates
# would differ from the ones just plotted, and the fit would be done twice.
dat_umap_trans = model_umap.embedding_
# Show the first five embedded points.
dat_umap_trans[0:5]
## array([[ 9.261274  ,  1.9660232 ],
##        [ 8.295108  ,  2.5176768 ],
##        [ 7.568429  ,  2.4515383 ],
##        [ 9.316274  ,  0.66552997],
##        [10.612278  ,  0.2595372 ]], dtype=float32)
# Dimensions of the embedding as a (rows, columns) tuple.
np.shape(dat_umap_trans)
## (333, 2)
# Handle to Python's main module so its variables can be read from R.
py <- import_main()

# Fetch the embedding computed in Python and name its columns explicitly:
# as_tibble() on a matrix without column names only works through a
# deprecated compatibility path (tibble >= 2.0.0).
umap_mat_py <- py$dat_umap_trans
colnames(umap_mat_py) <- str_c("V", seq_len(ncol(umap_mat_py)))

result_umap_py <- as_tibble(umap_mat_py)

# Plot the Python-side embedding, coloured by species.
dat %>%
  bind_cols(result_umap_py) %>%
  ggplot() +
  aes(x = V1, y = V2, color = species) +
  geom_point(size = 0.75, alpha = 0.5)

逆変換を試す

# Send the embedded points back through the fitted model's inverse transform.
dat_for_inv = dat_umap_trans
dat_pred = model_umap.inverse_transform(dat_for_inv)
# Show the first five reconstructed rows.
dat_pred[:5]
## array([[-0.9026765 ,  0.8777814 , -1.2740825 , -0.8012936 ],
##        [-0.81351095,  0.18741204, -1.0375887 , -0.6418856 ],
##        [-0.80961215,  0.22537063, -0.5507034 , -1.1476117 ],
##        [-1.1665847 ,  0.9870761 , -0.4798257 , -0.6067533 ],
##        [-0.7661266 ,  1.7842672 , -0.6767104 , -0.5975826 ]],
##       dtype=float32)

逆変換で得た値は、標準化した元データ（dat_for_umap）に確かに近い値になっている。

# Handle to Python's main module so its variables can be read from R.
py <- import_main()

# Fetch the inverse-transformed values computed in Python. The columns are
# named V1, V2, ... explicitly because as_tibble() on a matrix without column
# names only works through a deprecated compatibility path (tibble >= 2.0.0),
# and the reshaping step that follows depends on these names.
dat_pred_mat <- py$dat_pred
colnames(dat_pred_mat) <- str_c("V", seq_len(ncol(dat_pred_mat)))

dat_pred <- as_tibble(dat_pred_mat)

# Original feature names, used to label the reshaped data.
feature_names <- names(dat_for_umap)

# Pair every standardised input value (x) with its inverse-transformed
# counterpart (V) in long format:
#   * the inputs are renamed x1, x2, ... so that they line up with the
#     V1, V2, ... columns of dat_pred;
#   * names_pattern splits each name into its first character (which becomes
#     the value column, x or V) and its second character (the feature index);
#   * the index is turned into a factor labelled with the original feature
#     names in a single step.
# NOTE: the one-character pattern only handles single-digit feature indices.
dat_g <-
  dat_for_umap %>%
  rename_with(~ str_c("x", seq_along(.x))) %>%
  bind_cols(dat_pred) %>%
  pivot_longer(cols = everything(),
               names_to = c(".value", "tag"),
               names_pattern = "(.)(.)") %>%
  mutate(tag = factor(tag,
                      levels = as.character(seq_along(feature_names)),
                      labels = feature_names))
# Scatter the inverse-transformed score (V) against the standardised input (x),
# one panel per feature. The blue segment from (-2, -2) to (2, 2) is the
# y = x reference; its data has no `tag` column, so it is drawn in every panel.
identity_line <- data.frame(x = c(-2, 2), V = c(-2, 2))

ggplot(dat_g, aes(x = x, y = V)) +
  geom_path(data = identity_line, color = "blue") +
  geom_point(alpha = 0.5, size = 0.2) +
  facet_wrap(vars(tag), nrow = 1, scales = "free") +
  labs(y = "inv. trans. score")

# Per feature, summarise the reconstruction error (inverse-transformed value
# minus standardised input) by its mean and standard deviation.
dat_g %>%
  group_by(tag) %>%
  summarise(mean = mean(V - x),
            sd = sd(V - x))
## # A tibble: 4 x 3
##   tag                   mean    sd
##   <fct>                <dbl> <dbl>
## 1 bill_length_mm    -0.00287 0.185
## 2 bill_depth_mm     -0.00186 0.187
## 3 flipper_length_mm  0.00108 0.207
## 4 body_mass_g       -0.0100  0.221


2022年2月19日