##Forecasting at the aggregate level
library(ggplot2)
library(dplyr)
library(sp)
# Two fixed colours taken from sp's bpy palette (entries 12 and 3 of 15).
bpy_org <- bpy.colors(15)[12]
bpy_blue <- bpy.colors(15)[3]
# windowsFonts()/windowsFont() exist only in R for Windows. Skip the font
# registration elsewhere so the script does not stop with
# "could not find function" on other operating systems.
if (.Platform$OS.type == "windows") {
  windowsFonts(Calibri = windowsFont("Calibri"))
}
# Load the saved workspace objects used below.
# NOTE(review): the ONS_* objects referenced later are presumably supplied by
# this file -- confirm.
load("chap4/data/ONS2018_allc.RData")

# Models that were fitted, labelled by their UKHLS/ONS split
indmod <- c("50/50", "33/67", "25/75", "20/80", "10/90", "100/0")
# One flag per entry of indmod: set an entry to FALSE if that model was not fitted
indTF <- c(TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)
# Labels of the fitted models
mod2 <- indmod[indTF]
# Matching positions in supp.labst ("100/0" is label 1, the others are labels 2-6)
type <- c(2:6, 1)[indTF]

# ONS data set-up
# Build one table per element 1-4 of the ONS_*_dat lists (stacked below as
# parity 0-3) with columns age, coh, N, n and ONS. Keep only cells with
# N > 0, 0 <= n <= N and ages 15-44.
data4 <- lapply(1:4, function(k) {
  births <- ONS_births_dat[[k]]
  tab <- data.frame(age = births$Group.1,
                    coh = births$Group.2,
                    N = ONS_expos_dat[[k]]$x,
                    n = births$x,
                    ONS = ONS_rates_dat[[k]]$x)
  valid <- tab$N > 0 & tab$n >= 0 & tab$n <= tab$N & tab$age %in% 15:44
  # which() drops rows where the condition is NA instead of returning NA rows
  tab[which(valid), ]
})
# Stack the four tables in long format; element k is labelled parity k - 1
data4l <- do.call(
  rbind,
  lapply(seq_along(data4), function(k) data.frame(parity = k - 1, data4[[k]]))
)
# y runs from 44.5 down to 15.5 (2018 - 2003 = 15) and x = 2019 - y.
# NOTE(review): stepdat is not used in the visible part of the script.
stepdat <- data.frame(y = seq(44, 2018 - 2003) + 0.5)
stepdat$x <- 2019 - stepdat$y

# Full age x cohort x parity grid; N is attached from data4l where a cell exists there.
dat_expos <- expand.grid(age = 15:44, coh = 1945:2003, parity = 0:3)
# Name the join keys and the selected columns explicitly, rather than relying
# on column positions and on whichever names the two tables happen to share
# (which also avoids the "Joining with `by = ...`" message).
dat_expos <- left_join(dat_expos,
                       data4l[, c("parity", "age", "coh", "N")],
                       by = c("age", "coh", "parity"))
# Cells with age + coh <= 2018 that found no match in data4l get N = 0;
# cells with age + coh > 2018 keep N = NA.
dat_expos[which(is.na(dat_expos$N) & dat_expos$age + dat_expos$coh <= 2018), "N"] <- 0
# Column sums of the fifth rates matrix, ignoring NA entries.
# NOTE(review): presumably the all-parity rates, giving the observed CFR per cohort -- confirm.
ONScfr <- apply(ONS_rates_mat[[5]], 2, sum, na.rm = TRUE)

#Forecast using probability samples
# For every fitted model j: load its results file and run 1000 simulation
# draws. Each draw projects the four parity-specific matrices forward for
# cohorts 1975-2003 with binomial draws and stores
#   dat_exposACP[[j]][[i]]  dat_expos with N replaced by the projected values
#                           plus a column b of drawn counts,
#   dat_exposAC[[j]][[i]]   age x cohort totals (bir, exp) and rate = bir/exp,
#   cfrdat[[j]][, i]        the per-cohort sum of rate over ages.
files <- paste0("chap4/results/ACprob_",gsub("/","",mod2),".RData")
dat_exposACP <- dat_exposAC <- cfrdat <- list()
# Fixed seed: every rbinom() call below consumes random numbers in sequence,
# so results are reproducible only while the order of the calls is unchanged.
set.seed(1)
for (j in 1:length(files)) {
  print(j)
  dat_exposACP[[j]] <- dat_exposAC[[j]] <- list()
  # One row per cohort (59 cohorts, 1945-2003), one column per draw
  cfrdat[[j]] <- matrix(0,59,1000)
  # NOTE(review): psurvf is not defined in this script; presumably load()
  # creates it as a list of per-draw data frames with columns coh and p1..p4,
  # with rows ordered age-within-cohort over the same grid as dat_expos -- confirm.
  load(files[j])
  for (i in 1:1000) {
    # Rows for ages 15-44 of the four ONS_expos_mat matrices (used below as
    # parity 0-3); negative entries of the last one are set to 0.
    ex0 <- ONS_expos_mat[[1]][paste0(15:44),]
    ex1 <- ONS_expos_mat[[2]][paste0(15:44),]
    ex2 <- ONS_expos_mat[[3]][paste0(15:44),]
    ex3 <- ONS_expos_mat[[4]][paste0(15:44),]; ex3[ex3<0] <- 0
    # Template for the drawn counts: positive entries become NA (to be drawn
    # below), zeros and existing NAs are left as they are. bt1-bt3 start as
    # copies, so all four share the same NA pattern.
    bt0 <- ONS_births_mat[[1]][paste0(15:44),]; bt0[bt0>0] <- NA
    bt1 <- bt2 <- bt3 <- bt0
    for (cohort in 1975:2003) {
      # p1..p4 for this cohort as a 4-column matrix, rows labelled by age
      # (assumes exactly 30 rows ordered by age 15-44 -- TODO confirm)
      pm <- matrix(unlist(psurvf[[i]][psurvf[[i]]$coh==cohort,paste0("p",1:4)]),ncol=4)
      rownames(pm) <- 15:44
      # Current values for this cohort, one column per parity 0-3
      ex <- cbind(ex0[,paste(cohort)],ex1[,paste(cohort)],ex2[,paste(cohort)],ex3[,paste(cohort)])
      # One step per NA entry in the parity-0 column, starting from age 2018 - cohort
      for (a in 1:sum(is.na(ex[,1]))) {
        a0 <- paste(2018-cohort+(a-1))
        a1 <- paste(2018-cohort+a)
        # One binomial draw per parity at age a0; s0 is what remains after the draw
        b0 <- rbinom(4,ex[a0,],pm[a0,])
        s0 <- ex[a0,]-b0
        # Age a1: each parity keeps its remainder and receives the draw from the
        # parity below; the last group is never reduced (b0[4] is not subtracted).
        ex[a1,1] <- s0[1]
        ex[a1,2] <- b0[1]+s0[2]
        ex[a1,3] <- b0[2]+s0[3]
        ex[a1,4] <- b0[3]+ex[a0,4]
        # Write the projected values back into the full matrices
        ex0[a1,paste(cohort)] <- ex[a1,1]
        ex1[a1,paste(cohort)] <- ex[a1,2]
        ex2[a1,paste(cohort)] <- ex[a1,3]
        ex3[a1,paste(cohort)] <- ex[a1,4]
        # Record the draws made at age a0
        bt0[a0,paste(cohort)] <- b0[1]
        bt1[a0,paste(cohort)] <- b0[2]
        bt2[a0,paste(cohort)] <- b0[3]
        bt3[a0,paste(cohort)] <- b0[4]
      }
      # One more draw at the age reached by the last step of the loop above;
      # nothing is propagated further.
      # NOTE(review): presumably this is always age 44 -- confirm.
      a <- a+1
      a0 <- paste(2018-cohort+(a-1))
      b0 <- rbinom(4,ex[a0,],pm[a0,])
      bt0[a0,paste(cohort)] <- b0[1]
      bt1[a0,paste(cohort)] <- b0[2]
      bt2[a0,paste(cohort)] <- b0[3]
      bt3[a0,paste(cohort)] <- b0[4]      
    }
    # Copy the matrices into long format. as.vector() unrolls column by column,
    # so this assumes the matrix columns are cohorts 1945-2003 and that the rows
    # of dat_expos within each parity run age-within-cohort -- TODO confirm.
    dat_exposACP[[j]][[i]] <- dat_expos
    dat_exposACP[[j]][[i]]$N[dat_exposACP[[j]][[i]]$parity==0] <- as.vector(ex0)
    dat_exposACP[[j]][[i]]$N[dat_exposACP[[j]][[i]]$parity==1] <- as.vector(ex1)
    dat_exposACP[[j]][[i]]$N[dat_exposACP[[j]][[i]]$parity==2] <- as.vector(ex2)
    dat_exposACP[[j]][[i]]$N[dat_exposACP[[j]][[i]]$parity==3] <- as.vector(ex3)
    dat_exposACP[[j]][[i]]$b <- NA
    dat_exposACP[[j]][[i]]$b[dat_exposACP[[j]][[i]]$parity==0] <- as.vector(bt0)
    dat_exposACP[[j]][[i]]$b[dat_exposACP[[j]][[i]]$parity==1] <- as.vector(bt1)
    dat_exposACP[[j]][[i]]$b[dat_exposACP[[j]][[i]]$parity==2] <- as.vector(bt2)
    dat_exposACP[[j]][[i]]$b[dat_exposACP[[j]][[i]]$parity==3] <- as.vector(bt3)
    # Cells still NA were not filled by the cohort loop; draw them from N and
    # the matching probabilities. naind1 indexes cells within one parity and
    # naind2 across all four; unlist() stacks p1..p4 in the same parity order.
    # This relies on bt1-bt3 sharing bt0's NA pattern.
    naind1 <- which(is.na(as.vector(bt0)))
    naind2 <- which(is.na(dat_exposACP[[j]][[i]]$b))
    dat_exposACP[[j]][[i]]$b[naind2] <- rbinom(length(naind2),dat_exposACP[[j]][[i]]$N[naind2],unlist(psurvf[[i]][naind1,paste0("p",1:4)]))
    # Collapse over parity: total draws (bir) and total of the four matrices (exp) per age x cohort
    dat_exposAC[[j]][[i]] <- data.frame(expand.grid(age=15:44,coh=1945:2003),
                                        bir=aggregate(dat_exposACP[[j]][[i]]$b,by=list(dat_exposACP[[j]][[i]]$age,dat_exposACP[[j]][[i]]$coh),sum)$x,
                                        exp=as.vector(ex0+ex1+ex2+ex3))
    dat_exposAC[[j]][[i]]$rate <- dat_exposAC[[j]][[i]]$bir/dat_exposAC[[j]][[i]]$exp
    # Sum of the age-specific rates within each cohort for this draw
    cfrdat[[j]][,i] <- aggregate(dat_exposAC[[j]][[i]]$rate,by=list(dat_exposAC[[j]][[i]]$coh),sum)$x
  }
}

# Facet labels, one per model; the names "1".."6" are the codes stored in `type`
supp.labst <- setNames(
  c("100% UKHLS, 0% ONS",
    "50% UKHLS, 50% ONS",
    "33% UKHLS, 67% ONS",
    "25% UKHLS, 75% ONS",
    "20% UKHLS, 80% ONS",
    "10% UKHLS, 90% ONS"),
  1:6
)

#Prepare for plots
#CFR plot
# Summarise the simulated CFR draws of one model.
#   t  value written to the `type` column
#   j  index of the model in the global list `cfrdat` (cohorts in rows, draws in columns)
# Returns a data frame with one row per cohort 1945-2003: the mean over draws
# (cfr) and the 5%, 95%, 25% and 75% quantiles (q5, q95, q25, q75).
cfrfunc <- function(t,j) {
  draws <- cfrdat[[j]]
  # All four quantiles in one pass; rows of `qs` follow the order of `probs`
  qs <- apply(draws, 1, quantile, probs = c(0.05, 0.25, 0.75, 0.95))
  data.frame(type = t,
             coh = 1945:2003,
             cfr = apply(draws, 1, mean),
             q5 = qs[1, ],
             q95 = qs[4, ],
             q25 = qs[2, ],
             q75 = qs[3, ])
}

# Stack the per-model CFR summaries into one long table
cfrplot <- do.call(rbind, lapply(seq_along(files), function(j) cfrfunc(type[j], j)))
# Replace the numeric model code by its facet label
cfrplot$type <- factor(supp.labst[cfrplot$type], levels = supp.labst)
# Repeat the ONScfr vector once per fitted model, so each model's block carries the same values
cfrplot$ONScfr <- rep(ONScfr, times = length(type))

# Ribbon data, first half of the rows labelled "90%" and second half "50%":
#   "90%" rows span q5-q25 (the lower piece of the 90% band),
#   "50%" rows span q25-q75 (the whole 50% band).
# cfr carries the ONScfr values that are drawn as points.
cfrplot1 <- data.frame(type = rep(cfrplot$type, times = 2),
                       coh = rep(cfrplot$coh, times = 2),
                       lower = c(cfrplot$q5, cfrplot$q25),
                       upper = c(cfrplot$q25, cfrplot$q75),
                       int = factor(rep(c("90%", "50%"), each = nrow(cfrplot)),
                                    levels = c("90%", "50%")),
                       cfr = rep(cfrplot$ONScfr, times = 2))
cfrplot1$type <- factor(supp.labst[cfrplot1$type], levels = supp.labst)
# Blank the ONS values for cohorts after 1974 so no point is drawn for them
is.na(cfrplot1$cfr) <- cfrplot1$coh > 1974

# Companion ribbon data with the same row layout as cfrplot1: the "90%" rows
# span q75-q95 (the upper piece of the 90% band). The "50%" rows are blanked
# completely, so this table only ever contributes the upper 90% piece.
cfrplot2 <- data.frame(type = rep(cfrplot$type, times = 2),
                       coh = rep(cfrplot$coh, times = 2),
                       lower = c(cfrplot$q75, cfrplot$q25),
                       upper = c(cfrplot$q95, cfrplot$q75),
                       int = factor(rep(c("90%", "50%"), each = nrow(cfrplot)),
                                    levels = c("90%", "50%")))
cfrplot2[which(cfrplot2$int == "50%"), ] <- NA
cfrplot2$type <- factor(supp.labst[cfrplot2$type], levels = supp.labst)

# Position of the vertical line in each facet: 2007 - 44 for the first label
# (100% UKHLS) and 2018 - 44 for the other five.
intdat <- data.frame(type = factor(supp.labst, levels = supp.labst),
                     xintercept = rep(c(2007, 2018), times = c(1, 5)) - 44)
# Keep only the rows of models that were actually fitted
intdat <- intdat[intdat$type %in% supp.labst[type], ]

#ASFR plot
# Posterior-mean ASFR of one model.
#   t  value written to the `type` column
#   j  index of the model in the global list `dat_exposAC`, whose elements are
#      the per-draw data frames with a `rate` column on the age x cohort grid
# Returns a data frame with columns type, age (15-44), coh (1945-2003) and
# asfr, the mean of `rate` over all stored draws.
asfrfunc <- function(t,j) {
  draws <- dat_exposAC[[j]]
  rate_sum <- Reduce("+", lapply(draws, function(x) x$rate))
  # Divide by the number of draws actually stored rather than a hard-coded
  # 1000, so the mean stays correct if the number of simulations changes.
  data.frame(type = t,
             expand.grid(age = 15:44, coh = 1945:2003),
             asfr = rate_sum / length(draws))
}

# Stack the per-model ASFR means into one long table
asfrplot <- do.call(rbind, lapply(seq_along(files), function(j) asfrfunc(type[j], j)))

# Split into two copies with the same row layout. The cut-off for age + coh is
# 2018 for model codes 2-6 and 2007 for model code 1.
#   asfrplotp keeps the rows up to the cut-off (drawn with solid lines),
#   asfrplotf keeps the rows after it (drawn with dashed lines).
# Rows outside each part are blanked, not dropped, so both copies stay
# row-aligned with asfrplot.
asfrplotp <- asfrplot
asfrplotp[with(asfrplot, which(type %in% 2:6 & age + coh > 2018)), ] <- NA
asfrplotp[with(asfrplot, which(type == 1 & age + coh > 2007)), ] <- NA
asfrplotf <- asfrplot
asfrplotf[with(asfrplot, which(type %in% 2:6 & age + coh <= 2018)), ] <- NA
asfrplotf[with(asfrplot, which(type == 1 & age + coh <= 2007)), ] <- NA
# Only now replace the numeric model code by its facet label (the masks above need the code)
asfrplot$type <- factor(supp.labst[asfrplot$type], levels = supp.labst)

#Plots
#CFR posterior distributions (Figure 4.36)
# The 90% band is drawn as two pieces on either side of the 50% band: q5-q25
# and q25-q75 come from cfrplot1, q75-q95 from the row-aligned cfrplot2
# (presumably so the translucent fills do not overlap -- confirm).
png(filename = "chap4/plots/fig36.png", width = 30, height = 15, units = "cm", res = 400)
# print() explicitly: a ggplot object is only drawn when printed, and top-level
# auto-printing does not happen when the script is run through source().
print(
  ggplot(cfrplot1, aes(x = coh, y = lower, fill = int)) +
    geom_ribbon(aes(ymin = lower,
                    ymax = upper,
                    fill = int), color = NA) +
    geom_ribbon(aes(ymin = cfrplot2$lower,
                    ymax = cfrplot2$upper,
                    fill = cfrplot2$int), color = NA, show.legend = FALSE) +
    geom_vline(data = intdat, aes(xintercept = xintercept)) +
    geom_point(aes(y = cfr, shape = "ONS"), show.legend = TRUE) +
    # The guide must be passed by name; an unnamed argument is taken as the scale name
    scale_shape_manual(values = c("ONS" = 19),
                       guide = guide_legend(order = 3, title = "Observed rates")) +
    labs(x = "Cohort", y = "CFR") +
    scale_x_continuous(limits = c(1945, 2003), breaks = seq(1945, 2000, 5),
                       minor_breaks = setdiff(1945:2003, seq(1945, 2000, 5)),
                       expand = c(0.02, 0.02)) +
    scale_fill_manual("Credible interval",
                      guide = guide_legend(order = 1, override.aes = list(pch = c(NA, NA))),
                      values = alpha(c("blue", "blue"), c(0.3, 0.5)), na.translate = FALSE) +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.3),
          text = element_text("Calibri"), legend.position = "bottom") +
    facet_wrap(~type, nrow = 2)
)
dev.off()

#ASFR posterior means (Figure 4.38)
# One line per cohort and facet: solid for the rows kept in asfrplotp, dashed
# for the rows kept in asfrplotf. Both tables are row-aligned with asfrplot,
# which supplies the colour (coh) and the facet (type).
png(filename = "chap4/plots/fig38.png", width = 30, height = 15, units = "cm", res = 400)
# print() explicitly: a ggplot object is only drawn when printed, and top-level
# auto-printing does not happen when the script is run through source().
print(
  ggplot(asfrplot, aes(x = age, y = asfr, color = coh)) +
    geom_line(aes(x = asfrplotp$age, y = asfrplotp$asfr, group = asfrplotp$coh)) +
    geom_line(aes(x = asfrplotf$age, y = asfrplotf$asfr, group = asfrplotf$coh), linetype = 2) +
    labs(x = "Age", y = "ASFR", color = "Cohort") +
    scale_color_gradientn(colours = rainbow(100, start = 0.3, end = 1),
                          guide = guide_colorbar(barheight = 20, frame.colour = "black",
                                                 ticks.colour = "black"),
                          breaks = seq(1945, 2000, 5), limits = c(1945, 2003)) +
    scale_x_continuous(limits = c(15, 44), breaks = seq(15, 44, 5),
                       minor_breaks = setdiff(15:44, seq(15, 44, 5)),
                       expand = c(0.02, 0.02)) +
    # Zoom the y axis without dropping data outside the range
    coord_cartesian(ylim = c(0, 0.2)) +
    theme_bw() + theme(text = element_text("Calibri")) +
    facet_wrap(~type, nrow = 2)
)
dev.off()
