You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

126 lines
3.7 KiB

3 years ago
  1. library(ggplot2)
  2. library(gganimate)
  3. library(patchwork)
  # Central finite-difference estimate of the derivative of `f` at `x`.
  #
  # f: function to differentiate (works elementwise on a vector `x` when `f`
  #    itself is vectorised)
  # x: point(s) at which to estimate the slope
  # d: half-width of the symmetric step; the estimate uses f(x + d) and f(x - d)
  #
  # Returns (f(x + d) - f(x - d)) / (2 * d).
  gradient <- function(f, x, d) {
    upper <- f(x + d)
    lower <- f(x - d)
    (upper - lower) / (2 * d)
  }
  # Take one gradient-ascent step from `x`.
  #
  # f:  objective function to maximise
  # x:  current point
  # d:  half-width of the finite-difference step passed to gradient()
  # mu: learning rate (step size)
  #
  # Returns x + mu * gradient(f, x, d).
  gradient.ascent.move <- function(f, x, d, mu) {
    step <- mu * gradient(f, x, d)
    x + step
  }
  # Example objective functions for the gradient-ascent demos below.

  # Identity: f(x) = x
  func.g <- function(x) {
    x
  }

  # f(x) = sin(x)
  func.k <- function(x) {
    sin(x)
  }

  # f(x) = x * sin(x)
  func.h <- function(x) {
    x * sin(x)
  }

  # f(x) = 2 + cos(x) + sin(2x)
  func.l <- function(x) {
    2 + cos(x) + sin(2 * x)
  }
  # Apply `n` gradient-ascent steps starting from `x` and return the final point.
  #
  # f:  objective function to maximise
  # x:  starting point
  # d:  half-width of the finite-difference step used by gradient()
  # mu: learning rate (step size)
  # n:  number of steps; n <= 0 performs no steps and returns `x` unchanged
  #
  # Implemented as a loop rather than recursion, so large `n` cannot exhaust
  # R's expression-nesting limit.
  gradient.ascent.iterate <- function(f, x, d, mu, n) {
    for (step in seq_len(max(0, n))) {
      x <- gradient.ascent.move(f, x, d, mu)
    }
    x
  }
  # Run `n` gradient-ascent steps from `x` and return every visited point.
  #
  # f:  objective function to maximise
  # x:  starting point (the starting point itself is not included in the result)
  # d:  half-width of the finite-difference step used by gradient()
  # mu: learning rate (step size)
  # n:  number of steps; n <= 0 performs no steps
  # xs: values already recorded; the new iterates are appended after them
  #
  # Returns c(xs, x_1, ..., x_n), where x_k is the point after k steps.
  # Iterative with a preallocated buffer, so any `n` works without deep
  # recursion or repeated vector growth.
  gradient.ascent.iterverb <- function(f, x, d, mu, n, xs = numeric()) {
    steps <- seq_len(max(0, n))
    visited <- numeric(length(steps))
    for (k in steps) {
      x <- gradient.ascent.move(f, x, d, mu)
      visited[k] <- x
    }
    c(xs, visited)
  }
  # Build a long-format trace of gradient ascent for every combination of
  # starting point and learning rate.
  #
  # f:   objective function to maximise
  # x:   vector of starting points
  # d:   half-width of the finite-difference step used by gradient()
  # eta: vector of learning rates
  # n:   number of ascent steps per trace
  # xs:  not used here; kept so the signature matches plot.ascent() and
  #      animate.ascent()
  #
  # Returns a data.frame with one row per visited point and columns
  # x, y = f(x), i (iteration; 0 is the starting point), start_x (starting
  # point as character) and eta. Rows are ordered by starting point, then
  # learning rate, then iteration.
  trace.ascent <- function(f, x, d, eta, n, xs) {
    # Zero-row template fixes the column set even when `x` or `eta` is empty.
    template <- data.frame(x = numeric(),
                           y = numeric(),
                           i = integer(),
                           start_x = character(),
                           eta = numeric())
    pieces <- vector("list", length(x) * length(eta))
    k <- 0L
    for (start in x) {
      for (e in eta) {
        xf <- gradient.ascent.iterverb(f, start, d, e, n)
        k <- k + 1L
        # Prepend the starting point as iteration 0 so that every eta facet
        # shows where its trace began.
        pieces[[k]] <- data.frame(x = c(start, xf),
                                  y = c(f(start), f(xf)),
                                  i = c(0, seq_along(xf)),
                                  start_x = as.character(start),
                                  eta = e)
      }
    }
    # Bind once at the end instead of growing the data.frame inside the loop.
    do.call(rbind, c(list(template), pieces))
  }
  # Plot gradient-ascent traces next to the objective curve.
  #
  # f:   objective function to maximise (its expression, e.g. `func.h`, is
  #      used in the labels)
  # x:   vector of starting points
  # d:   half-width of the finite-difference step used by gradient()
  # eta: vector of learning rates; one facet row per value
  # n:   number of ascent steps per trace
  # xs:  grid of x values at which the curve f(xs) is drawn
  #
  # Returns a patchwork object. The left panel shows f with the visited points
  # (size = iteration, colour = start x). The right panel shows f(x) against
  # the iteration number. Both panels are faceted by eta.
  plot.ascent <- function(f, x, d, eta, n, xs) {
    df_dc <- trace.ascent(f, x, d, eta, n, xs)
    # deparse() may return several strings for long or multi-line expressions;
    # collapse them so the labels below are single strings.
    func_str <- paste(trimws(deparse(substitute(f))), collapse = " ")
    y_label <- sprintf("%s(x)", func_str)
    df_f <- data.frame(x = xs, y = f(xs))
    p1 <- ggplot(df_f, aes(x = x, y = y)) +
      geom_line() +
      geom_point(aes(colour = start_x, size = i), data = df_dc) +
      labs(size = "iteration",
           color = "start x",
           y = y_label) +
      facet_grid(eta ~ ., labeller = label_both)
    p2 <- ggplot(df_dc, aes(x = i, y = y)) +
      geom_line(aes(colour = start_x), show.legend = FALSE) +
      labs(x = "iteration", y = y_label) +
      facet_grid(eta ~ ., labeller = label_both)
    (p1 | p2) +
      plot_annotation(title = sprintf("function: %s", func_str)) +
      plot_layout(guides = "collect",
                  widths = 10,
                  heights = 2)
  }
  # Build an animated version of the gradient-ascent trace plot.
  #
  # f:   objective function to maximise (its expression, e.g. `func.h`, is
  #      used in the labels)
  # x:   vector of starting points
  # d:   half-width of the finite-difference step used by gradient()
  # eta: vector of learning rates; one facet row per value
  # n:   number of ascent steps per trace
  # xs:  grid of x values at which the curve f(xs) is drawn
  #
  # Returns a ggplot object with a gganimate transition_reveal() over the
  # iteration column `i`.
  animate.ascent <- function(f, x, d, eta, n, xs) {
    df_dc <- trace.ascent(f, x, d, eta, n, xs)
    # deparse() may return several strings for long or multi-line expressions;
    # collapse them so the title and axis label are single strings.
    func_str <- paste(trimws(deparse(substitute(f))), collapse = " ")
    df_f <- data.frame(x = xs, y = f(xs))
    p <- ggplot(df_f, aes(x = x, y = y)) +
      geom_line() +
      geom_point(aes(colour = start_x), size = 2.5, data = df_dc) +
      labs(color = "start x", y = sprintf("%s(x)", func_str)) +
      facet_grid(eta ~ ., labeller = label_both) +
      ggtitle(sprintf("function: %s", func_str))
    # Reveal the data progressively along the iteration index `i`.
    p + transition_reveal(i)
  }