
# Response model for multinomial (categorical) observations. It extends
# GLMresponse and declares no additional slots of its own.
setClass(Class = "MULTINOMresponse", contains = "GLMresponse")

# M-step for a multinomial response: re-estimates the (optionally weighted)
# multinomial-logit coefficients with nnet.default() and writes them back into
# the object with setpars().
#
# object: a MULTINOMresponse; reads @y, @x, @parameters$coefficients and
#   @family$base (the reference category, whose coefficients stay at 0).
# w: optional case weights, one per row of @x. When missing or NULL every case
#   gets weight 1, which is nnet.default()'s documented default.
# Returns the object with updated parameters.
setMethod("fit","MULTINOMresponse",
	function(object, w) {
		if (missing(w)) w <- NULL
		pars <- object@parameters
		base <- object@family$base # reference category (coefficients fixed at 0)
		y <- object@y
		x <- object@x
		ncat <- ncol(pars$coefficients)
		if (is.null(w)) w <- rep(1, NROW(x))
		# mask is (1 + nx) x ny (x are inputs, y are output levels):
		# row 1 is nnet's automatic bias node and is fixed at 0,
		# and the base category's column is fixed at 0 as well
		mask <- matrix(1, nrow = nrow(pars$coefficients), ncol = ncat)
		mask[, base] <- 0
		mask <- rbind(0, mask)
		# starting weights: current coefficients, bias row left at 0
		Wts <- mask
		Wts[-1, ] <- pars$coefficients
		# fewer than 3 outcome columns: logistic outputs with entropy loss;
		# otherwise a softmax output layer
		binary <- NCOL(y) < 3
		fit <- nnet.default(x, y, weights = w, size = 0,
			entropy = binary, softmax = !binary, skip = TRUE,
			mask = mask, Wts = Wts, trace = FALSE)
		# reshape fit$wts to the same (1 + nx) x ny layout as Wts, drop the
		# bias row, and transpose so that unlist() below flattens the
		# coefficients row by row (the order handed to setpars())
		pars$coefficients <- t(matrix(fit$wts, ncol = ncat,
			nrow = nrow(pars$coefficients) + 1)[-1, ])
		object <- setpars(object, unlist(pars))
		object
	}
)

# Per-row log-likelihood: for each row t, the log of sum_j y[t, j] * p[t, j],
# where p is the matrix of predicted category probabilities from predict().
setMethod("logDens", "MULTINOMresponse",
	function(object) {
		probs <- predict(object)
		mass <- rowSums(object@y * probs)
		log(mass)
	}
)

# Per-row likelihood: sum_j y[t, j] * p[t, j] with p from predict().
# With log = TRUE the natural log of that value is returned instead.
setMethod("dens", "MULTINOMresponse",
	function(object, log = FALSE) {
		lik <- rowSums(object@y * predict(object))
		if (!log) return(lik)
		log(lik)
	}
)

# Predicted values for each row of @x.
# With an "identity" link this is the raw linear predictor
# x %*% coefficients; otherwise the family's linkinv() is applied to it,
# passing the family's base (reference) category.
setMethod("predict", "MULTINOMresponse",
	function(object) {
		fam <- object@family
		eta <- object@x %*% object@parameters$coefficients
		if (fam$link == "identity") return(eta)
		fam$linkinv(eta, base = fam$base)
	}
)

# Simulate multinomial counts from the model's predicted category probabilities.
#
# object: a MULTINOMresponse. predict(object) gives one probability row per
#   time point, and rowSums(object@y) gives the number of trials for that row.
# nsim: number of simulated replicates.
# seed: optional value passed to set.seed() before drawing.
# times: optional row indices to simulate; all rows are used when missing.
# Returns an (nsim * length(times)) x ncategories matrix of counts. Time varies
# fastest: all time points of replicate 1 come first, then replicate 2, and so on.
setMethod("simulate",signature(object="MULTINOMresponse"),
	function(object,nsim=1,seed=NULL,times) {
		if (!is.null(seed)) set.seed(seed)
		pr <- predict(object)
		n <- rowSums(object@y)
		if (!missing(times)) {
			# drop = FALSE keeps a one-row matrix when length(times) == 1
			pr <- pr[times, , drop = FALSE]
			n <- n[times]
		}
		nt <- nrow(pr)
		ncat <- ncol(pr)
		# One rmultinom() call per time point, each with its own number of
		# trials n[i]. rmultinom() only uses a scalar size, so a per-row call
		# is required. The result is an ncat x nsim x nt integer array.
		sims <- vapply(seq_len(nt),
			function(i) rmultinom(nsim, size = n[i], prob = pr[i, ]),
			matrix(0L, nrow = ncat, ncol = nsim), USE.NAMES = FALSE)
		matrix(aperm(sims, c(3, 2, 1)), nrow = nsim * nt, ncol = ncat)
	}
)
