suzuzusu日記

(´・ω・`)

尤度関数におけるガウス分布とスチューデントのt分布の比較

RStanを使って尤度関数がガウス分布と場合とスチューデントのt分布の場合でどのように異なるのかを比較します. 最終的にガウス分布に比べてスチューデントのt分布が外れ値に頑健な回帰をすることが可能であることを確認します.

線形回帰

今回は以下のパラメータ\alpha,\betaの単純な線形回帰モデルを使用します.

f(x)=\alpha+\beta x

近似させた \widehat{y} \approx f(x) に対する尤度関数を教師データy_tとしてガウス分布の場合は\mathcal{N}(y_t|\widehat{y}, \sigma),スチューデントのt分布の場合は\mathcal{St}(y_t|\widehat{y}, \nu, \sigma)とします.

データ

求めたい関数を以下に定義して,ノイズと外れ値を含むデータを作成します.

f(x)=100 + 2x

f:id:suzuzusu:20191117151710p:plain
サンプルデータ

上記の画像のようなデータを使用します.

結果

RStanを使用してMCMCサンプリングをしてパラメータ\alpha,\betaの事後分布の平均を使用すると以下のような回帰曲線が得られます.

y_{gaussian} = 295.9909 + 1.904833 x

y_{st} = 103.6868 + 1.985594 x

f:id:suzuzusu:20191117151736p:plain

ガウス分布では外れ値に引っ張られた回帰をしているのに対して,スチューデントのt分布の方が頑健な回帰をしていることが分かります.

実装

library(rstan)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)
library(ggplot2)

# データ作成
set.seed(1)
N <- 100
x = sort(sample(1:1000, N))
alpha = 100
beta = 2
y = alpha + beta*x
y = y + rnorm(N, sd=30)
ris = sample(1:100, 30)
for (ri in ris) {
  # 外れ値
  y[ri] <- y[ri] + runif(1, 0, 1000)
}


dat <- list(N = N, x = x, y = y)
# 尤度関数がスチューデントのt分布
fit_t <- stan(file = 'student_t.stan', data = dat)
# 尤度関数がガウス分布
fit_g <- stan(file = 'gaussian.stan', data = dat)

x_tst = c(1:1000)

# パラメータの平均値を取得
m <- get_posterior_mean(fit_g)
alpha_g = m[1]
beta_g = m[2]
sigma_g = m[3]
y_mean_g = beta_g*x_tst + alpha_g

m <- get_posterior_mean(fit_t)
alpha_t = m[1]
beta_t = m[2]
sigma_t = m[3]
nu_t = m[4]
y_mean_t = beta_t*x_tst + alpha_t


p <- data.frame(x = x,
                y = y,
                x_tst = x_tst,
                y_mean_g = y_mean_g,
                y_mean_t = y_mean_t)

g <- ggplot(data=p) +
  geom_point(aes(x=x, y=y, colour="")) +
  labs(colour="sample")
plot(g)
ggsave(file = "sample.png", plot = g)


g <- ggplot(data=p) +
  geom_point(aes(x=x, y=y)) +
  geom_line(aes(x=x_tst, y=y_mean_g, colour="red")) +
  geom_line(aes(x=x_tst, y=y_mean_t, colour="darkblue")) +
  scale_color_discrete(name = "likelihood", labels = c("Student t", "Gaussian"))
plot(g)
ggsave(file = "summary.png", plot = g)

gaussian.stan

data {
  int<lower=0> N;
  vector[N] x;
  vector[N] y;
}
parameters {
  real alpha;
  real beta;
  real<lower=0> sigma;
}
model {
  y ~ normal(alpha + beta * x, sigma);
}

student_t.stan

data {
  int<lower=0> N;
  vector[N] x;
  vector[N] y;
}
parameters {
  real alpha;
  real beta;
  real nu;
  real<lower=0> sigma;
}
model {
  y ~ student_t(nu, alpha + beta * x, sigma);
}

参考