##Backtesting results
library(ggplot2)
library(dplyr)
library(boot)
library(rstan)
# Register Calibri with the Windows graphics device (windowsFonts() is a
# Windows-only call) so the plots below can use element_text("Calibri").
windowsFonts(Calibri = windowsFont("Calibri"))

# Inputs. NOTE(review): the ONS .RData presumably provides ONS_births_dat,
# ONS_expos_dat and ONS_rates_dat used below — confirm.
load("chap4/data/ONS2018_allc.RData")
load("chap3/results/Qmeanimp.RData")
source("chap4/scripts/important_functions.r")

# Integrated model type (fitted with back = TRUE and back = FALSE to each
# parity) whose results are evaluated; index 2 selects "33/67".
mod <- c("50/50", "33/67", "25/75", "20/80", "10/90")[2]

# ONS data set-up ----
# data4[[k]] holds the ONS series for parity k - 1 (see data4l below):
# age, cohort, exposure N, births n and the ONS rate. Only ages 15-44 and
# rows with N > 0 and 0 <= n <= N are kept.
# NOTE(review): assumes the births, exposure and rate objects share the same
# row order — TODO confirm.
data4 <- list()
for (i in 1:4) {
  data4[[i]] <- data.frame(age = ONS_births_dat[[i]]$Group.1,
                           coh = ONS_births_dat[[i]]$Group.2,
                           N = ONS_expos_dat[[i]]$x,
                           n = ONS_births_dat[[i]]$x,
                           ONS = ONS_rates_dat[[i]]$x)
  data4[[i]] <- data4[[i]][with(data4[[i]],
                                which(N > 0 & n >= 0 & n <= N & age %in% 15:44)), ]
}

# Long format with an explicit parity column (3 stands for parity 3+).
data4l <- rbind(
  data.frame(parity = 0, data4[[1]]),
  data.frame(parity = 1, data4[[2]]),
  data.frame(parity = 2, data4[[3]]),
  data.frame(parity = 3, data4[[4]])
)

# Set-up ----
# Saved fits for the selected model type `mod`. Entries 1-4 are parities 0-3
# fitted to ONS data to 2018 (type 1); entries 5-8 are the same parities
# fitted to ONS data to 2013 (type 2, files ending "_2013").
files <- c(
  paste0("chap4/results/p", c(0, 1, 2, 3), "_", gsub("/", "", mod), ".RData"),
  paste0("chap4/results/p", c(0, 1, 2, 3), "_", gsub("/", "", mod), "_2013.RData")
)
parity <- rep(0:3, times = 2)
type <- rep(1:2, each = 4)

# Empty containers; presumably filled by the per-parity scripts sourced in
# the extraction loop (not visible here).
results <- list()
results$Yhatm <- results$Adat <- results$AQAQdat <- results$Cdat <-
  results$Tdat <- results$ACACdat <- results$ACdat <- results$fdat <-
  results$probdat <- list()
results$psurvf <- results$psurvfm <- results$psurvfl <- results$psurvfu <-
  results$psumfm <- results$psumfl <- results$psumfu <- results$wsumfm <-
  results$wsumfl <- results$wsumfu <- list()

#Extract results from fitted models
# For each saved fit files[n]: run the set-up script for its parity, load the
# fit, then run the matching clean-up script. source() runs these scripts in
# the global environment.
# NOTE(review): they presumably read n, p and ind and append to `results` —
# confirm against chap4/scripts/p*_setup.r and p*clean.r.
for (n in 1:length(files)) {
  # Progress indicator.
  print(n)
  p <- parity[n]
  if (p == 0) source("chap4/scripts/p0_setup.r")
  if (p == 1) source("chap4/scripts/p1_setup.r")
  if (p == 2) source("chap4/scripts/p2_setup.r")
  if (p == 3) source("chap4/scripts/p3_setup.r")
  # NOTE(review): negative index dropping the first 64 elements; its consumer
  # (presumably the clean-up script) and the meaning of 64 are not visible here.
  ind <- -c(1:64)
  load(files[n])
  source(paste0("chap4/scripts/p",p,"clean.r"))
}

# Processing ----
# Smooth terms: stack the per-model summaries in `results` (one element per
# entry of `files`) into long data frames tagged with `type`
# (1 = fitted to ONS data to 2018, 2 = fitted to ONS data to 2013).
agerange <- 15:44
cohrange <- 1945:2003
Adat <- data.frame(type = rep(1:2, each = 4 * length(agerange)),
                   Reduce(rbind, results$Adat))
Cdat <- data.frame(type = rep(1:2, each = 4 * length(cohrange)),
                   Reduce(rbind, results$Cdat))
# NOTE(review): assumes 3 * 11 rows per type; their meaning is not visible here.
Tdat <- data.frame(type = rep(1:2, each = 3 * 11),
                   Reduce(rbind, results$Tdat))

# Age-cohort terms. The "p" copies blank (set to NA) whole rows whose calendar
# year age + coh is after the end of the model's data; the "f" copies blank
# the rows up to and including it. Cutoff: 2018 for type 1, 2013 for type 2.
ACACdat <- data.frame(type = rep(1:2, each = 4 * length(agerange) * length(cohrange)),
                      Reduce(rbind, results$ACACdat))
ACACdatp <- ACACdatf <- ACACdat
ACACdatp[which(ACACdatp$age + ACACdatp$coh > 2018 & ACACdat$type == 1), ] <- NA
ACACdatf[which(ACACdatf$age + ACACdatf$coh <= 2018 & ACACdat$type == 1), ] <- NA
ACACdatp[which(ACACdatp$age + ACACdatp$coh > 2013 & ACACdat$type == 2), ] <- NA
ACACdatf[which(ACACdatf$age + ACACdatf$coh <= 2013 & ACACdat$type == 2), ] <- NA

ACdat <- data.frame(type = rep(1:2, each = 4 * length(agerange) * length(cohrange)),
                    Reduce(rbind, results$ACdat))
ACdatp <- ACdatf <- ACdat
ACdatp[which(ACdatp$age + ACdatp$coh > 2018 & ACdat$type == 1), ] <- NA
ACdatf[which(ACdatf$age + ACdatf$coh <= 2018 & ACdat$type == 1), ] <- NA
ACdatp[which(ACdatp$age + ACdatp$coh > 2013 & ACdat$type == 2), ] <- NA
ACdatf[which(ACdatf$age + ACdatf$coh <= 2013 & ACdat$type == 2), ] <- NA

# Marginalised probabilities: one row per model x age x cohort (age varying
# fastest) holding results$psurvfm / psurvfl / psurvfu (presumably the
# posterior mean and interval bounds — TODO confirm). These rows are joined
# to the ONS rates, and the raw ONS rates are appended as type 3 rows.
fdatACall <- data.frame(
  type = rep(1:2, each = 4 * length(agerange) * length(cohrange)),
  parity = rep(parity, each = length(agerange) * length(cohrange)),
  age = rep(rep(agerange, length(cohrange)), length(files)),
  coh = rep(rep(cohrange, each = length(agerange)), length(files)),
  mean = unlist(results$psurvfm),
  lower = unlist(results$psurvfl),
  upper = unlist(results$psurvfu)
)
fdatACall <- left_join(fdatACall,
                       data4l[, c("parity", "age", "coh", "ONS")],
                       by = c("parity", "age", "coh"))
fdatACall <- rbind(
  fdatACall,
  data.frame(data4l[, c("parity", "age", "coh")],
             mean = data4l$ONS, lower = NA, upper = NA, ONS = NA, type = 3)
)

# Forecast exposures ----
# fill_exposures(): exposure `ON` for one parity on the full age x cohort
# grid (expand.grid order, age varying fastest). It is built from that
# parity's ONS data frame `ons` (columns age, coh, N):
#   * grid cells present in `ons` take its exposure N;
#   * absent cells with age + coh <= last_year (inside the data period) get 0;
#   * remaining absent cells (after last_year) take, for the same age, the
#     exposure of the last cohort in grid order that has a value.
# The in-period test uses this parity's own age/coh columns rather than
# another parity's grid.
# Defaults read the globals agerange and cohrange at call time.
fill_exposures <- function(ons, ages = agerange, cohorts = cohrange,
                           last_year = 2018) {
  grid <- data.frame(expand.grid(age = ages, coh = cohorts))
  ons <- ons[, c("age", "coh", "N")]
  names(ons)[3] <- "ON"
  out <- left_join(grid, ons, by = c("age", "coh"))
  out$ON[which(is.na(out$ON) & out$age + out$coh <= last_year)] <- 0
  for (a in ages) {
    same_age <- out$age == a
    out$ON[which(is.na(out$ON) & same_age)] <-
      tail(out$ON[which(!is.na(out$ON) & same_age)], 1)
  }
  out
}

# Stacked exposures for parities 0-3 (columns parity, age, coh, ON).
expf <- rbind(data.frame(parity = 0, fill_exposures(data4[[1]])),
              data.frame(parity = 1, fill_exposures(data4[[2]])),
              data.frame(parity = 2, fill_exposures(data4[[3]])),
              data.frame(parity = 3, fill_exposures(data4[[4]])))

# Incorporate additional binomial uncertainty ----
# For every model, each row of results$psurvf[[n]] (presumably one posterior
# draw of the cell probabilities; columns are age x cohort cells with age
# varying fastest) is turned into a simulated observed rate
# Binomial(exposure, prob) / exposure. The simulated rates are then
# summarised per cell as meanO and the 2.5% / 97.5% quantiles lowerO / upperO,
# which are stored in fdatACall.
for (n in seq_along(files)) {
  p <- parity[n]
  t <- type[n]
  psurvfn <- results$psurvf[[n]]
  # Exposures for this parity, in the same age-fastest cell order.
  exposure <- expf$ON[expf$parity == p]
  Opsurvfr <- psurvfn
  # Simulate every saved draw. A fixed 1:1000 would silently leave extra rows
  # un-simulated, or fail when fewer draws are saved.
  for (i in seq_len(nrow(psurvfn))) {
    Opsurvfr[i, ] <- rbinom(length(agerange) * length(cohrange),
                            exposure, psurvfn[i, ]) / exposure
  }
  resdat <- data.frame(type = t, parity = p,
                       age = rep(agerange, length(cohrange)),
                       coh = rep(cohrange, each = length(agerange)),
                       meanO = apply(Opsurvfr, 2, mean),
                       lowerO = apply(Opsurvfr, 2, quantile, probs = 0.025, na.rm = TRUE),
                       upperO = apply(Opsurvfr, 2, quantile, probs = 0.975, na.rm = TRUE))
  if (n == 1) {
    # First model: adds the meanO/lowerO/upperO columns (NA for other rows).
    fdatACall <- left_join(fdatACall, resdat,
                           by = c("type", "parity", "age", "coh"))
  } else {
    # Later models: fill their rows by column name. This relies on
    # fdatACall's rows for (t, p) being in the same age-fastest order as resdat.
    rows <- fdatACall$type == t & fdatACall$parity == p
    stopifnot(sum(rows) == nrow(resdat))
    fdatACall[rows, c("meanO", "lowerO", "upperO")] <-
      resdat[, c("meanO", "lowerO", "upperO")]
  }
}

# Model type as a factor for faceting: 1 and 2 are model fits, 3 is ONS data.
fdatACall$type <- factor(fdatACall$type, levels = 1:3)

# Plotting copies. The "p" copy keeps calendar years (age + coh) up to the end
# of the data each row is based on; the "f" copy keeps the later years. Other
# rows are set entirely to NA. Cutoff: 2018 for types 1 and 3, 2013 for type 2.
fdatACallp <- fdatACallf <- fdatACall
fdatACallp[which(fdatACallp$age + fdatACallp$coh > 2018 &
                   fdatACall$type %in% c(1, 3)), ] <- NA
fdatACallf[which(fdatACallf$age + fdatACallf$coh <= 2018 &
                   fdatACall$type %in% c(1, 3)), ] <- NA
fdatACallp[which(fdatACallp$age + fdatACallp$coh > 2013 &
                   fdatACall$type == 2), ] <- NA
fdatACallf[which(fdatACallf$age + fdatACallf$coh <= 2013 &
                   fdatACall$type == 2), ] <- NA

# Rows with an ONS rate to compare against, plus the ONS rows themselves.
fdatACallred <- fdatACall[which(!is.na(fdatACall$ONS) |
                                  fdatACall$type == 3), ]
# Per cohort and parity: the sum over age of the ONS rates.
psumONS <- aggregate(data4l$ONS, by = list(data4l$coh, data4l$parity), sum)
# Cohorts born after 2018 - 44 had not reached age 44 by 2018, so their
# summed rate is incomplete: blank it.
psumONS[which(psumONS$Group.1 > (2018 - 44)), "x"] <- NA
colnames(psumONS) <- c("cov", "parity", "ONS")

# Model counterparts (results$psumfm / psumfl / psumfu, presumably posterior
# mean and interval bounds — TODO confirm), one row per model x cohort.
psumdat <- data.frame(type = rep(1:2, each = 4 * length(cohrange)),
                      parity = rep(parity, each = length(cohrange)),
                      cov = rep(cohrange, length(files)),
                      mean = unlist(results$psumfm),
                      lower = unlist(results$psumfl),
                      upper = unlist(results$psumfu))
psumdat <- left_join(psumdat, psumONS, by = c("cov", "parity"))

# "p" keeps cohorts that reach age 44 within the model's data and "f" keeps
# the rest. The cutoff follows the model type, as in the other p/f splits:
# 2018 for type 1 (data to 2018) and 2013 for type 2 (data to 2013).
psum_last_year <- ifelse(psumdat$type == 1, 2018, 2013)
psumdatp <- psumdatf <- psumdat
psumdatp[which(psumdat$cov + 44 > psum_last_year), ] <- NA
psumdatf[which(psumdat$cov + 44 <= psum_last_year), ] <- NA

# Plots ----
# Facet strip labels, named by the parity / model-type value they label.
supp.labsp <- c(`0` = "Parity 0", `1` = "Parity 1", `2` = "Parity 2", `3` = "Parity 3+")
supp.labst <- c(`1` = "ONS data to 2018", `2` = "ONS data to 2013", `3` = "ONS data")

# Posterior mean marginalised probabilities (Figure 4.29) ----
# Facet rows: type (1 = data to 2018, 2 = data to 2013, 3 = ONS data);
# facet columns: parity. Solid lines come from fdatACallp (in-sample years),
# dashed lines from fdatACallf (later years).
png(filename = "chap4/plots/fig29.png", width = 20, height = 15,
    units = "cm", res = 400)
pind <- fdatACall$parity >= 0  # selects every row (parities are 0-3)
# Explicit print(): a ggplot object is only auto-printed at the interactive
# top level. Without it, source()-ing this script without echo writes an
# empty PNG.
print(
  ggplot(fdatACall[pind, ], aes(x = age, y = mean, color = coh)) +
    geom_line(aes(x = fdatACallp$age[pind], y = fdatACallp$mean[pind],
                  group = fdatACallp$coh[pind]),
              size = 0.3) +
    geom_line(aes(x = fdatACallf$age[pind], y = fdatACallf$mean[pind],
                  group = fdatACallf$coh[pind]),
              linetype = 2, size = 0.3) +
    labs(x = "Age", y = "Fitted probability/Observed rate", color = "Cohort") +
    scale_color_gradientn(colours = rainbow(100, start = 0.3, end = 1),
                          guide = guide_colorbar(barheight = 20,
                                                 frame.colour = "black",
                                                 ticks.colour = "black"),
                          breaks = seq(1945, 2000, 5), limits = c(1945, 2003)) +
    scale_x_continuous(limits = c(15, 44), breaks = seq(15, 44, 5),
                       minor_breaks = setdiff(15:44, seq(15, 44, 5)),
                       expand = c(0.02, 0.02)) +
    coord_cartesian(ylim = c(0, 0.35)) +
    theme_bw() + theme(text = element_text("Calibri")) +
    facet_grid(type ~ parity,
               labeller = labeller(parity = supp.labsp, type = supp.labst))
)
dev.off()

# Cohort main effects (Figure 4.30) ----
# Cohort smooth f_C(c) per parity, with columns fit / lower / upper
# (presumably an interval band — TODO confirm), coloured by model type.
# Cdat2 and Cdat4 each keep one type's rows (all other columns NA for the
# other type), so each type is drawn as its own set of lines.
Cdat <- data.frame(type = rep(1:2, each = 4 * length(cohrange)),
                   Reduce(rbind, results$Cdat))
Cdat$type <- factor(supp.labst[Cdat$type], levels = supp.labst[c(1, 2)])
Cdat2 <- Cdat4 <- Cdat
Cdat2[Cdat2$type != supp.labst[1], -c(1)] <- NA
Cdat4[Cdat4$type != supp.labst[2], -c(1)] <- NA

png(filename = "chap4/plots/fig30.png", width = 20, height = 9.35,
    units = "cm", res = 400)
ind <- which(Cdat$type %in% supp.labst[c(1, 2)])  # rows of both model types
# Explicit print(): a ggplot object is only auto-printed at the interactive
# top level. Without it, source()-ing this script without echo writes an
# empty PNG.
print(
  ggplot(Cdat[ind, ], aes(x = cov, y = fit, color = type)) +
    geom_line(aes(x = Cdat2$cov[ind], y = Cdat2$fit[ind], color = Cdat2$type[ind]), size = 1) +
    geom_line(aes(x = Cdat2$cov[ind], y = Cdat2$lower[ind], color = Cdat2$type[ind]), size = 0.5) +
    geom_line(aes(x = Cdat2$cov[ind], y = Cdat2$upper[ind], color = Cdat2$type[ind]), size = 0.5) +
    geom_line(aes(x = Cdat4$cov[ind], y = Cdat4$fit[ind], color = Cdat4$type[ind]), size = 1) +
    geom_line(aes(x = Cdat4$cov[ind], y = Cdat4$lower[ind], color = Cdat4$type[ind]), size = 0.5) +
    geom_line(aes(x = Cdat4$cov[ind], y = Cdat4$upper[ind], color = Cdat4$type[ind]), size = 0.5) +
    labs(x = "Cohort (c)", y = expression(paste(f[C], "(c)", sep = "")),
         color = "Model") +
    scale_y_continuous(breaks = seq(-4, 4, 1)) +
    scale_x_continuous(limits = c(1945, 2003), breaks = seq(1950, 2000, 10),
                       minor_breaks = seq(1945, 2005, 10),
                       expand = c(0.02, 0.02)) +
    coord_cartesian(ylim = c(-3, 4)) +
    theme_gray() +
    theme(text = element_text("Calibri"), legend.position = "bottom") +
    facet_wrap(~parity, nrow = 1, labeller = labeller(parity = supp.labsp))
)
dev.off()

# Plot marginalised probabilities with predictive uncertainty (Figure 4.31) ----
# Point shape for the ONS rates:
# - "Training" covers all rows of type 1 (data to 2018) and type-2 rows with
#   calendar year age + coh <= 2013.
# - Every other row is "Test".
fdatACall$shape <- factor(
  ifelse(fdatACall$type == 1 |
           (fdatACall$type == 2 & fdatACall$age + fdatACall$coh <= 2013),
         "Training", "Test"),
  levels = c("Training", "Test")
)

# Draw the marginalised probabilities with binomial predictive uncertainty
# (Figures 4.31-4.32) and save them to a 30 x 15 cm PNG at 400 dpi.
#
# Arguments:
#   pind     logical row selector for fdatACall. It is also applied to
#            fdatACallp / fdatACallf, which are row-aligned copies.
#   filename path of the PNG to write.
#   ylim     upper limit of the probability axis (the lower limit is 0).
# Returns `filename` invisibly; called for its side effect.
#
# Reads the globals fdatACall, fdatACallp, fdatACallf, intdat, supp.labsp and
# supp.labst. Thick lines show meanO; thin lines show the 2.5% / 97.5%
# quantiles lowerO / upperO. Solid lines are in-sample years, dashed lines
# later years.
plotfunc <- function(pind, filename, ylim) {
  png(filename = filename, width = 30, height = 15, units = "cm", res = 400)
  # Close the device even if building or printing the plot fails, so later
  # output is not silently drawn into this PNG.
  on.exit(dev.off(), add = TRUE)
  print(
    ggplot(fdatACall[pind, ], aes(x = age, y = meanO, color = coh)) +
      geom_line(aes(x = fdatACallp$age[pind], y = fdatACallp$meanO[pind],
                    group = fdatACallp$coh[pind]),
                size = 0.8) +
      geom_line(aes(x = fdatACallf$age[pind], y = fdatACallf$meanO[pind],
                    group = fdatACallf$coh[pind]),
                linetype = 2, size = 0.8) +
      geom_line(aes(x = fdatACallp$age[pind], y = fdatACallp$lowerO[pind],
                    group = fdatACallp$coh[pind]),
                size = 0.3) +
      geom_line(aes(x = fdatACallf$age[pind], y = fdatACallf$lowerO[pind],
                    group = fdatACallf$coh[pind]),
                linetype = 2, size = 0.3) +
      geom_line(aes(x = fdatACallp$age[pind], y = fdatACallp$upperO[pind],
                    group = fdatACallp$coh[pind]),
                size = 0.3) +
      geom_line(aes(x = fdatACallf$age[pind], y = fdatACallf$upperO[pind],
                    group = fdatACallf$coh[pind]),
                linetype = 2, size = 0.3) +
      # Age at which each cohort reaches the end of the training data.
      geom_vline(aes(xintercept = xintercept, color = coh), data = intdat) +
      geom_ribbon(aes(ymin = lowerO, ymax = upperO, fill = coh),
                  color = NA, alpha = 0.15, show.legend = FALSE) +
      geom_point(aes(y = ONS, shape = shape)) +
      # `guide` must be passed by name. Unnamed, the legend spec is not
      # matched to `guide` and falls through `...` positionally, landing on
      # the scale name.
      scale_shape_manual(values = c("Training" = 1, "Test" = 19),
                         guide = guide_legend(title = "ONS data", order = 1)) +
      labs(x = "Age", y = "Probability", color = "Cohort (c)") +
      scale_fill_gradientn(colours = rainbow(100, start = 0.3, end = 1),
                           guide = guide_colorbar(barheight = 20,
                                                  frame.colour = "black",
                                                  ticks.colour = "black"),
                           breaks = seq(1945, 2000, 5), limits = c(1945, 2003)) +
      scale_color_gradientn(colours = rainbow(100, start = 0.3, end = 1),
                            guide = guide_colorbar(barheight = 20,
                                                   frame.colour = "black",
                                                   ticks.colour = "black"),
                            breaks = seq(1945, 2000, 5), limits = c(1945, 2003)) +
      scale_x_continuous(limits = c(15, 44), breaks = seq(15, 44, 5),
                         minor_breaks = setdiff(15:44, seq(15, 44, 5)),
                         expand = c(0.02, 0.02)) +
      coord_cartesian(ylim = c(0, ylim)) +
      theme_bw() + theme(text = element_text("Calibri")) +
      facet_grid(parity + type ~ paste0("c = ", coh),
                 labeller = labeller(parity = supp.labsp, type = supp.labst))
  )
  invisible(filename)
}

# Cohorts shown in Figures 4.31-4.32. For each model type, xintercept is the
# age each cohort had reached at the end of its data (2013 for type 2, 2018
# for type 1); plotfunc() draws these as vertical lines.
cohs <- seq(1972, 2000, 4)
intdat <- rbind(data.frame(type = 2, coh = cohs, xintercept = 2013 - cohs),
                data.frame(type = 1, coh = cohs, xintercept = 2018 - cohs))

# Select rows by the full column name. `fdatACall$p` only resolved to
# `parity` through `$` partial matching.
plotfunc(fdatACall$parity %in% c(0, 1) & fdatACall$type %in% c(1, 2) &
           fdatACall$coh %in% cohs,
         "chap4/plots/fig31.png", 0.35)
plotfunc(fdatACall$parity %in% c(2, 3) & fdatACall$type %in% c(1, 2) &
           fdatACall$coh %in% cohs,
         "chap4/plots/fig32.png", 0.6)
