##Modelling Q|A,C
library(dplyr)
library(ggplot2)
library(RColorBrewer)
library(rstan)
# Register the Calibri family used by the plot themes below.
# windowsFonts()/windowsFont() are only available in Windows builds of
# grDevices, so the registration is skipped elsewhere instead of stopping the
# whole script with "could not find function".
if (.Platform$OS.type == "windows") {
  windowsFonts(Calibri = windowsFont("Calibri"))
}

#Set-up
# Display labels for the qualification codings and for parity; each vector is
# named by the category code so it can be used as a lookup / facet labeller.
supp.labsq4 <- setNames(c("< GCSE","GCSE","A Level","Degree"), c(1,2,3,4))
supp.labsq2b <- setNames(c("< A Level","At least A Level"), c(1,2))
supp.labsq3b <- setNames(c("< GCSE","GCSE/A Level","Degree"), c(1,2,3))
supp.labsp <- setNames(c("Parity 0","Parity 1","Parity 2"), c(0,1,2))

# Coordinates of the step line drawn on every heat map: y runs from 44.5 down
# to 25.5 with x = 2008 - y, followed by a closing point at (1982.5, 14.5).
step.y <- 0.5 + 44:(2007 - 1982)
stepdat <- rbind(
  data.frame(y = step.y, x = 2008 - step.y),
  c(14.5, 1982.5)
)

# Age and cohort ranges spanned by the prediction grid.
agerange <- 15:44
cohrange <- 1945:2003

#Compute observed proportions
#Parity 0 function
# Tabulate the parity-0 sample by age x cohort x qualification and compute the
# share of each qualification level within its age x cohort cell.
#
# FUN: summary function applied to the weights by aggregate(); the script
#      calls this with `length` and with `sum`.
#
# Reads weights_0st, age_0, coh_0, qualf4_0 and supp.labsq4, which are not
# arguments and must be visible from the calling environment.
#
# Returns a data frame with Group.1 (age), Group.2 (cohort), Group.3
# (qualification code), x.x (cell value), x.y (age x cohort value), x
# (= x.x / x.y), qualft (labelled factor), Parity (0) and type (1).
ACQtab0func <- function(FUN) {
  # drop = FALSE keeps every age x cohort x qualification combination, so
  # combinations without data appear as NA rows.
  cells <- aggregate(weights_0st, by = list(age_0, coh_0, qualf4_0), FUN = FUN, drop = FALSE)
  keep <- which(cells$Group.1 + cells$Group.2 <= 2007 & cells$Group.2 <= 1982)
  cells <- cells[keep, ]

  # Age x cohort totals; both tables have a column `x`, hence the x.x / x.y
  # suffixes after the join.
  totals <- aggregate(weights_0st, by = list(age_0, coh_0), FUN = FUN)
  ACQtab0 <- left_join(cells, totals, by = c("Group.1", "Group.2"))

  # A qualification level with no cases inside an observed age x cohort cell
  # is a genuine zero, not missing data. ACQtab1func and ACQtab2func already
  # apply this fill; without it the parity-0 table has fewer non-NA cells than
  # there are fitted residuals for the observed cells.
  ACQtab0$x.x[which(is.na(ACQtab0$x.x) & !is.na(ACQtab0$x.y))] <- 0

  ACQtab0$x <- ACQtab0$x.x / ACQtab0$x.y
  ACQtab0$qualft <- factor(supp.labsq4[ACQtab0$Group.3], levels = supp.labsq4)
  ACQtab0$Parity <- 0
  ACQtab0$type <- 1
  ACQtab0
}

#Parity 1 function
# Tabulate the parity-1 sample by age x cohort x qualification (two-level
# coding) and compute each level's share within its age x cohort cell.
#
# FUN: summary function applied to the weights by aggregate().
# Reads weights_1st, age_1, coh_1, qualf2b_1 and supp.labsq2b from the
# calling environment.
# Returns a data frame with Group.1 (age), Group.2 (cohort), Group.3
# (qualification code), x.x, x.y, x (= x.x / x.y), qualft, Parity (1), type (1).
ACQtab1func <- function(FUN) {
  by_cell <- aggregate(weights_1st, by=list(age_1,coh_1,qualf2b_1), FUN = FUN, drop=FALSE)
  in_window <- by_cell$Group.1 + by_cell$Group.2 <= 2007 & by_cell$Group.2 <= 1982
  by_cell <- by_cell[which(in_window), ]

  by_agecoh <- aggregate(weights_1st, by=list(age_1,coh_1), FUN = FUN)
  out <- left_join(by_cell, by_agecoh, by=c("Group.1","Group.2"))

  # Empty qualification level inside an observed age x cohort cell -> zero.
  unobserved <- which(is.na(out$x.x) & !is.na(out$x.y))
  out$x.x[unobserved] <- 0

  out$x <- out$x.x / out$x.y
  out$qualft <- factor(supp.labsq2b[out$Group.3], levels=supp.labsq2b)
  out$Parity <- 1
  out$type <- 1
  out
}

#Parity 2 function
# Tabulate the parity-2 sample by age x cohort x qualification (three-level
# coding) and compute each level's share within its age x cohort cell.
#
# FUN: summary function applied to the weights by aggregate().
# Reads weights_2st, age_2, coh_2, qualf3b_2 and supp.labsq3b from the
# calling environment.
# Returns a data frame with Group.1 (age), Group.2 (cohort), Group.3
# (qualification code), x.x, x.y, x (= x.x / x.y), qualft, Parity (2), type (1).
ACQtab2func <- function(FUN) {
  tab <- aggregate(weights_2st, by=list(age_2,coh_2,qualf3b_2), FUN = FUN, drop=FALSE)
  tab <- tab[which(tab$Group.2 <= 1982 & tab$Group.1 + tab$Group.2 <= 2007), ]
  tab <- left_join(
    tab,
    aggregate(weights_2st, by=list(age_2,coh_2), FUN = FUN),
    by=c("Group.1","Group.2")
  )

  # Empty qualification level inside an observed age x cohort cell -> zero.
  zero_rows <- which(!is.na(tab$x.y) & is.na(tab$x.x))
  tab$x.x[zero_rows] <- 0

  tab$x <- tab$x.x / tab$x.y
  tab$qualft <- factor(supp.labsq3b[tab$Group.3], levels=supp.labsq3b)
  tab$Parity <- 2
  tab$type <- 1
  tab
}

# Observed tables per parity: `length` gives the unweighted version, `sum` the
# weighted version (suffix w).
ACQtab0  <- ACQtab0func(length)
ACQtab1  <- ACQtab1func(length)
ACQtab2  <- ACQtab2func(length)
ACQtab0w <- ACQtab0func(sum)
ACQtab1w <- ACQtab1func(sum)
ACQtab2w <- ACQtab2func(sum)

# Stack the three parities; qualft gets one common level order across the
# three codings and Parity becomes a factor for faceting.
ACQtaball <- do.call(rbind, list(ACQtab0, ACQtab1, ACQtab2))
ACQtaball[["qualft"]] <- factor(
  ACQtaball[["qualft"]],
  levels=c("< GCSE","< A Level","GCSE","GCSE/A Level","A Level","At least A Level","Degree")
)
ACQtaball[["Parity"]] <- factor(ACQtaball[["Parity"]], levels=c(0,1,2))

ACQtaballw <- do.call(rbind, list(ACQtab0w, ACQtab1w, ACQtab2w))
ACQtaballw[["qualft"]] <- factor(
  ACQtaballw[["qualft"]],
  levels=c("< GCSE","< A Level","GCSE","GCSE/A Level","A Level","At least A Level","Degree")
)
ACQtaballw[["Parity"]] <- factor(ACQtaballw[["Parity"]], levels=c(0,1,2))

#Plot observed proportions (Figure 4.4)
# Draw the faceted age x cohort heat map of observed proportions and write it
# to a PNG file.
#
# dat:      data frame with columns Group.1 (plotted on the age axis),
#           Group.2 (cohort axis), x (fill value, scale limited to [0, 1]),
#           Parity and qualft (facet variables).
# filename: path of the PNG file to create.
#
# Also reads stepdat and supp.labsp from the calling environment.
# Returns the value of dev.off(), as before.
plotfunc <- function(dat,filename) {
  png(filename=filename,width=17.5,height=15,units="cm",res=400)
  # Remember the device just opened and make sure it is closed even when
  # building or printing the plot fails; otherwise the PNG device stays open
  # and later plots are drawn into the wrong file.
  plot_dev <- dev.cur()
  dev_closed <- FALSE
  on.exit(if (!dev_closed) dev.off(plot_dev), add = TRUE)

  fig <- ggplot(dat, aes(x=Group.2,y=Group.1,fill=x)) +
    geom_raster(hjust=0.5,vjust=0.5) +
    labs(x = "Cohort", y = "Age", fill="Proportion") +
    scale_fill_gradientn(colours=c("white",colorRampPalette(brewer.pal(9,"YlOrRd"))(10000),"black"), limits=c(0,1),guide = guide_colorbar(barheight = 20,frame.colour="black",ticks.colour="black"), breaks = seq(0,1,0.1)) +
    scale_x_continuous(expand=c(0,0),breaks=seq(1945,1990,5), minor_breaks=setdiff(1945:1992,seq(1945,1990,5))) +
    scale_y_continuous(expand=c(0,0),breaks=seq(15,44,5), minor_breaks=setdiff(15:44,seq(15,44,5))) +
    coord_cartesian(ylim=c(14.5,44.5),xlim=c(1944.5,1982.5)) +
    # Light grid lines on the cell boundaries.
    geom_vline(xintercept=1945.5:1981.5,color="lightgray",size=0.25) +
    geom_hline(yintercept=15.5:43.5,color="lightgray",size=0.25) +
    # Step line from stepdat, repeated in every facet.
    geom_step(data=stepdat, mapping=aes(x=x,y=y),direction="vh") +
    theme_bw() + theme(panel.grid.major = element_blank(),panel.grid.minor = element_blank(),axis.text.x = element_text(angle = 90, vjust=0.45),text = element_text("Calibri")) +
    facet_wrap(~Parity+qualft,labeller=labeller(Parity=supp.labsp),nrow=3)

  # A ggplot object is only rendered when printed; inside a function that has
  # to be done explicitly.
  print(fig)

  dev_closed <- TRUE
  dev.off(plot_dev)
}

plotfunc(ACQtaball, "chap4/plots/fig4_uw.png")  # unweighted
plotfunc(ACQtaballw, "chap4/plots/fig4_w.png")  # weighted

#Stan modelling
#Parity 0
# Per age x cohort cell (cohorts up to 1982): total number of cases (y) and the
# number of cases in each of the four qualification categories (q1..q4).
newdata0 <- aggregate(qualf4_0 ~ age_0 + coh_0, FUN = function(x) c(y=length(x), q1=length(x[x==1]), q2=length(x[x==2]), q3=length(x[x==3]), q4=length(x[x==4])), subset=coh_0 <= 1982)
# The vector-valued FUN yields a matrix column; data.frame() spreads it into
# separate columns.
newdata0 <- data.frame(newdata0$age_0,newdata0$coh_0,newdata0$qualf4_0)
colnames(newdata0) <- c("a","c","y","y1","y2","y3","y4")
# Centre age on median(15:44) = 29.5; the grid below is centred the same way
# so that the join on "a" matches.
newdata0$a <- newdata0$a - median(15:44)
# Sum of weights per cell.
# NOTE(review): assigned by position, so this assumes the second aggregate()
# returns exactly the same cells in the same order as the first - confirm that
# weights_0st and qualf4_0 have no differing missing values.
newdata0$yw <- aggregate(weights_0st ~ age_0 + coh_0, FUN = sum, subset = coh_0 <= 1982)$weights_0st
# Ratio of weighted to unweighted cell size (mean weight in the cell).
newdata0$wtmult <- newdata0$yw/newdata0$y
# Rename the category count columns to "1".."4" so they can be selected with
# paste0(1:qmax0).
colnames(newdata0) <- c("a","c","y",paste0(1:4),"yw","wtmult")
qmax0 <- 4

# Full age x cohort grid (agerange x cohrange) with the derived covariates.
newdata02 <- expand.grid(a=agerange,c=cohrange)
# Cohort centred on median(1945:1982) = 1963.5.
newdata02$cc <- newdata02$c-median(1945:1982)
# Cohort capped at 1972, and its 1-based index (1945 -> 1, 1972 and later -> 28).
newdata02$cind72 <- ifelse(newdata02$c >= 1972, 1972, newdata02$c)
newdata02$ccind72 <- newdata02$cind72-1944
# Centred cohort for cohorts from 1972 on, 0 for cohorts up to 1971.
newdata02$cc72 <- ifelse(newdata02$c<=1971,0,newdata02$cc)
# 1-based age index (15 -> 1), then age centred like newdata0$a.
newdata02$aind <- newdata02$a-14
newdata02$a <- newdata02$a-median(15:44)

# Attach the observed counts; grid cells without data get NA in those columns.
newdata02 <- left_join(newdata02,newdata0,by=c("a","c"))
Nac <- nrow(newdata02)
y <- newdata02[,paste0(1:qmax0)]
# Row indices of the grid cells that have data; unobserved cells get zero
# counts and zero weight multiplier.
yind <- which(!is.na(y[,1]))
Nobs <- length(yind)
y[is.na(y)] <- 0
cc72 <- newdata02$cc72
ccind72 <- newdata02$ccind72
Nc72 <- max(ccind72)
a <- newdata02$a
aind <- newdata02$aind
Na <- max(aind)
wt <- newdata02$wtmult
wt[is.na(wt)] <- 0

# Fit the Stan model with a single chain of 2000 iterations and store the fit.
# No seed is passed to stan(), so draws differ between runs.
standata <- list(Nac=Nac, Nc72=Nc72, Na=Na, Nobs=Nobs, y=y, ccind72=ccind72, cc72=cc72, aind=aind, a=a, yind=yind, wt=wt)
stanout <- stan(file="chap4/stan/p0_ACQ_5.2.stan",data=standata,chains=1,iter=2000)
save(stanout,file="chap4/results/p0_ACQ_5.2.RData")

#Parity 1
# Per age x cohort cell (cohorts up to 1982): total number of cases (y) and the
# number of cases in each of the two qualification categories (q1, q2).
newdata1 <- aggregate(qualf2b_1 ~ age_1 + coh_1, FUN = function(x) c(y=length(x), q1=length(x[x==1]), q2=length(x[x==2])), subset=coh_1 <= 1982)
# Spread the matrix column returned by the vector-valued FUN into columns.
newdata1 <- data.frame(newdata1$age_1,newdata1$coh_1,newdata1$qualf2b_1)
colnames(newdata1) <- c("a","c","y","y1","y2")
# Centre age on median(15:44) = 29.5, matching the grid below.
newdata1$a <- newdata1$a - median(15:44)
# Sum of weights per cell.
# NOTE(review): assigned by position; assumes both aggregate() calls return
# the same cells in the same order - confirm.
newdata1$yw <- aggregate(weights_1st ~ age_1 + coh_1, FUN = sum, subset = coh_1 <= 1982)$weights_1st
# Ratio of weighted to unweighted cell size.
newdata1$wtmult <- newdata1$yw/newdata1$y
# Category count columns are renamed to "1","2".
colnames(newdata1) <- c("a","c","y",paste0(1:2),"yw","wtmult")
qmax1 <- 2

# Full age x cohort grid with the same derived covariates as for parity 0.
newdata12 <- expand.grid(a=agerange,c=cohrange)
# Cohort centred on median(1945:1982) = 1963.5.
newdata12$cc <- newdata12$c-median(1945:1982)
# Cohort capped at 1972 and its 1-based index.
newdata12$cind72 <- ifelse(newdata12$c >= 1972, 1972, newdata12$c)
newdata12$ccind72 <- newdata12$cind72-1944
# Centred cohort for cohorts from 1972 on, 0 for cohorts up to 1971.
newdata12$cc72 <- ifelse(newdata12$c<=1971,0,newdata12$cc)
# 1-based age index, then centred age (needed as join key).
newdata12$aind <- newdata12$a-14
newdata12$a <- newdata12$a-median(15:44)

# Attach the observed counts; grid cells without data get NA.
newdata12 <- left_join(newdata12,newdata1,by=c("a","c"))
Nac <- nrow(newdata12)
y <- newdata12[,paste0(1:qmax1)]
# Row indices of observed cells; unobserved cells get zero counts and weight.
yind <- which(!is.na(y[,1]))
Nobs <- length(yind)
y[is.na(y)] <- 0
cc72 <- newdata12$cc72
ccind72 <- newdata12$ccind72
Nc72 <- max(ccind72)
aind <- newdata12$aind
Na <- max(aind)
wt <- newdata12$wtmult
wt[is.na(wt)] <- 0

# Unlike the parity 0 and parity 2 fits, the centred age `a` is not part of
# the data list passed to this model.
# No seed is passed to stan(), so draws differ between runs.
standata <- list(Nac=Nac, Nc72=Nc72, Na=Na, Nobs=Nobs, y=y, ccind72=ccind72, cc72=cc72, aind=aind, yind=yind, wt=wt)
stanout <- stan(file="chap4/stan/p1_ACQ_7.2.stan",data=standata,chains=1,iter=2000)
save(stanout,file="chap4/results/p1_ACQ_7.2.RData")

#Parity 2
# Per age x cohort cell (cohorts up to 1982): total number of cases (y) and the
# number of cases in each of the three qualification categories (q1..q3).
newdata2 <- aggregate(qualf3b_2 ~ age_2 + coh_2, FUN = function(x) c(y=length(x), q1=length(x[x==1]), q2=length(x[x==2]), q3=length(x[x==3])), subset=coh_2 <= 1982)
# Spread the matrix column returned by the vector-valued FUN into columns.
newdata2 <- data.frame(newdata2$age_2,newdata2$coh_2,newdata2$qualf3b_2)
colnames(newdata2) <- c("a","c","y","y1","y2","y3")
# Centre age on median(17:44) = 30.5.
# NOTE(review): parities 0 and 1 centre on median(15:44) = 29.5 while the grid
# still spans agerange 15:44 - confirm the different centring is intended.
# The grid below uses the same constant, so the join on "a" is consistent.
newdata2$a <- newdata2$a - median(17:44)
# Sum of weights per cell.
# NOTE(review): assigned by position; assumes both aggregate() calls return
# the same cells in the same order - confirm.
newdata2$yw <- aggregate(weights_2st ~ age_2 + coh_2, FUN = sum, subset = coh_2 <= 1982)$weights_2st
# Ratio of weighted to unweighted cell size.
newdata2$wtmult <- newdata2$yw/newdata2$y
# Category count columns are renamed to "1".."3".
colnames(newdata2) <- c("a","c","y",paste0(1:3),"yw","wtmult")
qmax2 <- 3

# Full age x cohort grid with the derived covariates.
newdata22 <- expand.grid(a=agerange,c=cohrange)
# Cohort centred on median(1945:1982) = 1963.5.
newdata22$cc <- newdata22$c-median(1945:1982)
# Cohort capped at 1972 and its 1-based index.
newdata22$cind72 <- ifelse(newdata22$c >= 1972, 1972, newdata22$c)
newdata22$ccind72 <- newdata22$cind72-1944
# Centred cohort for cohorts from 1972 on, 0 for cohorts up to 1971.
newdata22$cc72 <- ifelse(newdata22$c<=1971,0,newdata22$cc)
# 1-based age index, then age centred like newdata2$a.
newdata22$aind <- newdata22$a-14
newdata22$a <- newdata22$a-median(17:44)

# Attach the observed counts; grid cells without data get NA.
newdata22 <- left_join(newdata22,newdata2,by=c("a","c"))
Nac <- nrow(newdata22)
y <- newdata22[,paste0(1:qmax2)]
# Row indices of observed cells; unobserved cells get zero counts and weight.
yind <- which(!is.na(y[,1]))
Nobs <- length(yind)
y[is.na(y)] <- 0
cc72 <- newdata22$cc72
ccind72 <- newdata22$ccind72
Nc72 <- max(ccind72)
a <- newdata22$a
aind <- newdata22$aind
Na <- max(aind)
wt <- newdata22$wtmult
wt[is.na(wt)] <- 0

# Fit the Stan model with a single chain of 2000 iterations and store the fit.
# No seed is passed to stan(), so draws differ between runs.
standata <- list(Nac=Nac, Nc72=Nc72, Na=Na, Nobs=Nobs, y=y, ccind72=ccind72, cc72=cc72, aind=aind, a=a, yind=yind, wt=wt)
stanout <- stan(file="chap4/stan/p2_ACQ_5.2.stan",data=standata,chains=1,iter=2000)
save(stanout,file="chap4/results/p2_ACQ_5.2.RData")

#Extract posterior mean probabilities for full ACQ surface
# For each saved fit, compute the posterior mean of every element of the Stan
# parameter `theta` and arrange the means as a matrix with one row per row of
# `newdata` and one column per qualification category.
#
# files:   character vector of .RData files, each containing an object named
#          `stanout` (a stanfit).
# newdata: data frame with columns named "1".."qmax"; only its number of rows
#          and these column names are used.
# qmax:    number of qualification categories.
#
# Returns a list with one numeric matrix (nrow(newdata) x qmax) per file; an
# empty list when `files` is empty.
sumfunc <- function(files,newdata,qmax) {
  cols <- paste0(seq_len(qmax))
  propsa <- vector("list", length(files))
  for (i in seq_along(files)) {
    # Load into a private environment so that objects stored in the file
    # cannot overwrite this function's arguments or locals.
    fit_env <- new.env()
    load(files[i], envir = fit_env)
    if (!exists("stanout", envir = fit_env, inherits = FALSE)) {
      stop("No object named 'stanout' in ", files[i], call. = FALSE)
    }

    # permuted = FALSE gives an iterations x chains x parameters array.
    theta <- rstan::extract(fit_env$stanout, pars = "theta", permuted = FALSE)
    # Average over iterations and over all chains. For the single-chain fits
    # produced by this script this equals the previous chain-1 mean; the old
    # theta[, 1, ] silently ignored any further chains.
    theta_mean <- apply(theta, 3, mean)

    propsb <- newdata[, cols, drop = FALSE]
    if (length(theta_mean) != nrow(propsb) * qmax) {
      stop("theta in ", files[i], " has ", length(theta_mean),
           " elements but newdata/qmax imply ", nrow(propsb) * qmax,
           call. = FALSE)
    }
    # theta elements come out column-major, hence byrow = FALSE.
    propsb[, cols] <- matrix(theta_mean, nrow = nrow(propsb), byrow = FALSE)
    propsa[[i]] <- as.matrix(propsb)
  }
  propsa
}

# Posterior mean probability surfaces, one list (of length 1) per parity.
res0 <- sumfunc("chap4/results/p0_ACQ_5.2.RData", newdata02, qmax0)
res1 <- sumfunc("chap4/results/p1_ACQ_7.2.RData", newdata12, qmax1)
res2 <- sumfunc("chap4/results/p2_ACQ_5.2.RData", newdata22, qmax2)

#Create data frame for plotting
# expand.grid() varies age fastest, then cohort, then qualification, which is
# the order obtained by flattening the result matrix column by column.
#Parity 0
ACQdat0 <- cbind(
  expand.grid(age=agerange, coh=cohrange, qualf=1:4),
  thetam=as.vector(res0[[1]])
)
ACQdat0[["qualft"]] <- factor(supp.labsq4[ACQdat0[["qualf"]]], levels=supp.labsq4)
ACQdat0[["par"]] <- 0

#Parity 1
ACQdat1 <- cbind(
  expand.grid(age=agerange, coh=cohrange, qualf=1:2),
  thetam=as.vector(res1[[1]])
)
ACQdat1[["qualft"]] <- factor(supp.labsq2b[ACQdat1[["qualf"]]], levels=supp.labsq2b)
ACQdat1[["par"]] <- 1

#Parity 2
ACQdat2 <- cbind(
  expand.grid(age=agerange, coh=cohrange, qualf=1:3),
  thetam=as.vector(res2[[1]])
)
ACQdat2[["qualft"]] <- factor(supp.labsq3b[ACQdat2[["qualf"]]], levels=supp.labsq3b)
ACQdat2[["par"]] <- 2

# Stack the parities and put the qualification labels in one common order.
ACQdat <- do.call(rbind, list(ACQdat0, ACQdat1, ACQdat2))
ACQdat[["qualft"]] <- factor(
  ACQdat[["qualft"]],
  levels=c("< GCSE","< A Level","GCSE","GCSE/A Level","A Level","< Degree","At least A Level","Degree")
)

#Plot posterior mean probabilities (Figure 4.7)
# Heat map of the fitted probabilities over the full age x cohort grid,
# faceted by parity and qualification level, written to a PNG file.
png(filename="chap4/plots/fig7.png",width=17.5,height=15,units="cm",res=400)
# The plot is printed explicitly: top-level expressions are not auto-printed
# when the script is run through source(), which would leave the PNG empty.
print(
  ggplot(ACQdat,aes(x=coh,y=age,fill=thetam)) +
    geom_raster(hjust=0.5,vjust=0.5) +
    labs(x = "Cohort", y = "Age", fill="Fitted\nprobability") +
    scale_fill_gradientn(colours=c("white",colorRampPalette(brewer.pal(9,"YlOrRd"))(10000),"black"), limits=c(0,1),guide = guide_colorbar(barheight = 20,frame.colour="black",ticks.colour="black"), breaks = seq(0,1,0.1)) +
    scale_x_continuous(expand=c(0,0),breaks=seq(1945,2000,5), minor_breaks=setdiff(1945:2003,seq(1945,2000,5))) +
    scale_y_continuous(expand=c(0,0),breaks=seq(15,44,5), minor_breaks=setdiff(15:44,seq(15,44,5))) +
    coord_cartesian(ylim=c(14.5,44.5),xlim=c(1944.5,2003.5)) +
    # Light grid lines on the cell boundaries.
    geom_vline(xintercept=1945.5:2002.5,color="lightgray",size=0.25) +
    geom_hline(yintercept=15.5:43.5,color="lightgray",size=0.25) +
    # fill=x is kept on purpose: stepdat has no thetam column, so the layer
    # presumably needs its own fill mapping to override the inherited
    # fill=thetam. The step line itself has no fill.
    geom_step(data=stepdat, mapping=aes(x=x,y=y,fill=x),direction="vh") +
    theme_bw() + theme(panel.grid.major = element_blank(),panel.grid.minor = element_blank(),axis.text.x = element_text(angle = 90, vjust=0.45),text = element_text("Calibri")) +
    facet_wrap(~par+qualft,nrow=3,labeller=labeller(par=supp.labsp))
)
dev.off()

#Create data frames with weighted counts
# For each parity: scale the category counts ("1".."qmax") by the cell's
# weight multiplier, replace the cell total y by the weighted total yw, and
# drop the yw column.
newdata0w2 <- newdata0
newdata0w2[, paste0(1:qmax0)] <- newdata0[, paste0(1:qmax0)] * newdata0[["wtmult"]]
newdata0w2[["y"]] <- newdata0[["yw"]]
newdata0w2[["yw"]] <- NULL

newdata1w2 <- newdata1
newdata1w2[, paste0(1:qmax1)] <- newdata1[, paste0(1:qmax1)] * newdata1[["wtmult"]]
newdata1w2[["y"]] <- newdata1[["yw"]]
newdata1w2[["yw"]] <- NULL

newdata2w2 <- newdata2
newdata2w2[, paste0(1:qmax2)] <- newdata2[, paste0(1:qmax2)] * newdata2[["wtmult"]]
newdata2w2[["y"]] <- newdata2[["yw"]]
newdata2w2[["yw"]] <- NULL

#Compute Pearson residuals and prepare for plotting
#Chi-square function
# Compare observed category counts with the counts implied by fitted
# probabilities.
#
# res:     list of matrices of fitted probabilities (rows = grid cells,
#          columns = categories).
# newdata: data frame of observed cells with a column "y" (cell total) and
#          count columns named "1".."qmax".
# qmax:    number of categories.
# resid:   TRUE  -> return the Pearson residuals
#                   (observed - expected) / sqrt(expected) as matrices;
#          FALSE -> return their sum of squares (NA/NaN terms are skipped).
# ind:     row indices into each res matrix selecting the rows that
#          correspond, in order, to the rows of newdata.
#
# Returns a list of matrices (resid = TRUE) or a numeric vector
# (resid = FALSE), one element per element of res; empty for empty res.
chifunc <- function(res,newdata,qmax,resid=FALSE,ind) {
  cols <- paste0(seq_len(qmax))
  obs <- newdata[, cols, drop = FALSE]
  cell_total <- newdata[, "y"]

  chi <- if (resid) vector("list", length(res)) else numeric(length(res))
  for (i in seq_along(res)) {
    # drop = FALSE keeps a one-row selection as a matrix.
    pred <- res[[i]][ind, , drop = FALSE]
    # A shape mismatch would otherwise be recycled silently and pair
    # predictions with the wrong cells.
    if (nrow(pred) != nrow(obs) || ncol(pred) != ncol(obs)) {
      stop("res[[", i, "]][ind, ] is ", nrow(pred), " x ", ncol(pred),
           " but newdata has ", nrow(obs), " rows and qmax = ", qmax,
           call. = FALSE)
    }
    # Expected counts: each row of probabilities times that cell's total.
    predc <- pred * cell_total
    chires <- (obs - predc) / sqrt(predc)
    if (resid) {
      chi[[i]] <- as.matrix(chires)
    } else {
      chi[i] <- sum(chires^2, na.rm = TRUE)
    }
  }
  chi
}

# Residual tables: start from the observed tables and overwrite every non-NA
# proportion, in row order, with the Pearson residual of the corresponding
# cell (residual matrix flattened column by column). Only grid rows with
# observed data are passed to chifunc().
ACQtab0res <- ACQtab0
ACQtab0res$x <- replace(
  ACQtab0res$x,
  which(!is.na(ACQtab0res$x)),
  as.vector(chifunc(res0, newdata0w2, qmax0, TRUE, which(!is.na(newdata02[, "1"])))[[1]])
)

ACQtab1res <- ACQtab1
ACQtab1res$x <- replace(
  ACQtab1res$x,
  which(!is.na(ACQtab1res$x)),
  as.vector(chifunc(res1, newdata1w2, qmax1, TRUE, which(!is.na(newdata12[, "1"])))[[1]])
)

ACQtab2res <- ACQtab2
ACQtab2res$x <- replace(
  ACQtab2res$x,
  which(!is.na(ACQtab2res$x)),
  as.vector(chifunc(res2, newdata2w2, qmax2, TRUE, which(!is.na(newdata22[, "1"])))[[1]])
)

# Stack the parities and harmonise the factor levels for faceting.
ACQtabres <- do.call(rbind, list(ACQtab0res, ACQtab1res, ACQtab2res))
ACQtabres[["qualft"]] <- factor(
  ACQtabres[["qualft"]],
  levels=c("< GCSE","< A Level","GCSE","GCSE/A Level","A Level","At least A Level","Degree")
)
ACQtabres[["Parity"]] <- factor(ACQtabres[["Parity"]], levels=c(0,1,2))
# Truncate residuals to [-3, 3] so they fit the colour scale.
ACQtabres$x <- ifelse(abs(ACQtabres$x) <= 3, ACQtabres$x, 3 * sign(ACQtabres$x))

#Plot Pearson residuals (Figure 4.8)
# Heat map of the (truncated) Pearson residuals by age and cohort, faceted by
# parity and qualification level, written to a PNG file.
png(filename="chap4/plots/fig8.png",width=17.5,height=15,units="cm",res=400)
# The plot is printed explicitly: top-level expressions are not auto-printed
# when the script is run through source(), which would leave the PNG empty.
print(
  ggplot(ACQtabres,aes(x=Group.2,y=Group.1,fill=x)) +
    geom_raster(hjust=0.5,vjust=0.5) +
    labs(x = "Cohort", y = "Age", fill="Pearson\nresidual") +
    # Diverging scale on [-3, 3]; the end labels read "<= -3" and ">= 3.0"
    # because the residuals were truncated to that range.
    scale_fill_gradientn(colors=rev(c("red","orange","yellow","white","green","turquoise","blue")), limits=c(-3,3),guide = guide_colorbar(barheight = 20,frame.colour="black",ticks.colour="black"), breaks = seq(-3,3,0.5), labels=c(expression(""<="-3"),"-2.5","-2.0","-1.5","-1.0","-0.5","0.0","0.5","1.0","1.5","2.0","2.5",expression("">=3.0))) +
    scale_x_continuous(expand=c(0,0),breaks=seq(1945,1990,5), minor_breaks=setdiff(1945:1992,seq(1945,1990,5))) +
    scale_y_continuous(expand=c(0,0),breaks=seq(15,44,5), minor_breaks=setdiff(15:44,seq(15,44,5))) +
    coord_cartesian(ylim=c(14.5,44.5),xlim=c(1944.5,1982.5)) +
    # Light grid lines on the cell boundaries.
    geom_vline(xintercept=1945.5:1981.5,color="lightgray",size=0.25) +
    geom_hline(yintercept=15.5:43.5,color="lightgray",size=0.25) +
    # Step line from stepdat, repeated in every facet.
    geom_step(data=stepdat, mapping=aes(x=x,y=y),direction="vh") +
    theme_bw() + theme(panel.grid.major = element_blank(),panel.grid.minor = element_blank(),axis.text.x = element_text(angle = 90, vjust=0.45),text = element_text("Calibri")) +
    facet_wrap(~Parity+qualft,labeller=labeller(Parity=supp.labsp),nrow=3)
)
dev.off()
