# Author: Robert J. Hijmans
# Date :  August 2009
# Version 0.9
# License GPL v3

# Run 'fun(model, chunk, ...)' in parallel on the cluster 'cls'.
# The rows of data.frame 'd' are split into contiguous chunks of (at most)
# ceiling(nrow(d) / length(cls)) rows, one per cluster node, and the per-chunk
# results are combined in row order: with rbind if the first result has
# dimensions (matrix / data.frame), otherwise with unlist.
parfun <- function(cls, d, fun, model, ...) {
	nr <- nrow(d)
	nc <- length(cls)
	s <- split(d, rep(seq_len(nc), each=ceiling(nr/nc), length.out=nr))
	# The worker gets a small environment holding only 'fun' and 'model'.
	# A closure created directly in this frame would capture 'd' and all
	# chunks in 's', and that frame would be serialized with every task.
	wenv <- new.env(parent=globalenv())
	wenv$fun <- fun
	wenv$model <- model
	worker <- function(i, ...) fun(model, i, ...)
	environment(worker) <- wenv
	p <- parallel::clusterApply(cls, s, worker, ...)
	if (!is.null(dim(p[[1]]))) {
		do.call(rbind, p)
	} else {
		unlist(p)
	}
}


# Convert the value returned by a model prediction function to a numeric matrix.
# A list (including a data.frame) gets one numeric column per element, with
# factor columns becoming their integer codes; a factor vector becomes its
# integer codes.
.predAsMatrix <- function(r) {
	if (is.list(r)) {
		r <- as.data.frame(lapply(r, as.numeric))
	} else if (is.factor(r)) {
		r <- as.integer(r)
	}
	as.matrix(r)
}

# Predict for all rows of data.frame 'd' (split over cluster 'cls' when
# cores > 1) and return the result as a numeric matrix.
# For gstat models, the first two output columns are dropped when, in the
# first (up to) 5 rows, they equal the first two predictor columns (x and y).
# This check runs before any NA rows are re-inserted, so 'd' and the
# predictions are row-aligned.
.predRows <- function(model, fun, d, cores, cls, ...) {
	if (cores > 1) {
		r <- parfun(cls, d, fun, model, ...)
	} else {
		r <- fun(model, d, ...)
	}
	r <- .predAsMatrix(r)
	if (inherits(model, "gstat") && ncol(r) >= 2 && ncol(d) >= 2) {
		k <- min(nrow(d), nrow(r), 5)
		if (k > 0) {
			xy <- as.matrix(d[seq_len(k), 1:2])
			if (isTRUE(all(xy == r[seq_len(k), 1:2]))) {
				r <- r[, -(1:2), drop=FALSE]   # x, y
			}
		}
	}
	r
}

# Predict for a block of cells.
#   d      predictor values, one row per cell (coerced to a data.frame)
#   nl     number of output layers; used to size the result when no row
#          can be predicted
#   const  optional predictors whose columns are appended to 'd' with cbind
#   na.rm  if TRUE, rows with any NA are not passed to 'fun' and get NA output
#   index  optional output columns to keep
#   cores, cls  if cores > 1, rows are predicted in parallel on cluster 'cls'
# Returns a matrix with one row per row of 'd' and one column per output
# layer (assuming 'fun' returns one value or row per input row).
.runModel <- function(model, fun, d, nl, const, na.rm, index, cores=1, cls=NULL, ...) {
	if (!is.data.frame(d)) {
		d <- data.frame(d)
	}
	if (!is.null(const)) {
		for (j in seq_len(ncol(const))) {
			d <- cbind(d, const[, j, drop=FALSE])
		}
	}
	if (!na.rm) {
		r <- .predRows(model=model, fun=fun, d=d, cores=cores, cls=cls, ...)
	} else {
		n <- nrow(d)
		ok <- rowSums(is.na(d)) == 0
		d <- d[ok, , drop=FALSE]
		if (nrow(d) > 0) {
			r <- .predRows(model=model, fun=fun, d=d, cores=cores, cls=cls, ...)
			if (!all(ok)) {
				# re-insert NA rows for the skipped cells: n rows, one column per layer
				m <- matrix(NA, nrow=n, ncol=ncol(r))
				m[ok, ] <- r
				colnames(m) <- colnames(r)
				r <- m
			}
		} else {
			# nothing to predict: n rows of NA, wide enough for 'index' if given
			ncl <- if (is.null(index)) nl else max(index)
			r <- matrix(NA, nrow=n, ncol=ncl)
		}
	}
	if (!is.null(index)) {
		r <- r[, index, drop=FALSE]
	}
	r
}


# Determine factor levels of the model output from a sample of predictors 'd'.
# Prepares 'd' like .runModel does (coerce to data.frame, append the columns of
# 'const', drop rows with NA if na.rm) and calls fun(model, d, ...).
# Returns:
#   - for a factor prediction: a list with one data.frame(value, label), where
#     'value' is the integer code (1, 2, ...) of each level;
#   - for a list / data.frame prediction: a list with one element per output
#     column, holding such a data.frame for factor columns and NULL otherwise;
#   - NULL when there is nothing to predict or the output is not categorical.
# 'nl' and 'index' are accepted for signature compatibility but are not used.
.getFactors <- function(model, fun, d, nl, const, na.rm, index, ...) {
	if (!is.data.frame(d)) {
		d <- data.frame(d)
	}
	if (!is.null(const)) {
		for (j in seq_len(ncol(const))) {
			d <- cbind(d, const[, j, drop=FALSE])
		}
	}
	if (na.rm) {
		d <- d[rowSums(is.na(d)) == 0, , drop=FALSE]
	}
	if (nrow(d) == 0) {
		return(NULL)
	}
	r <- fun(model, d, ...)

	if (is.factor(r)) {
		# single categorical output; factor predictions are written as integer codes
		levs <- levels(r)
		return(list(data.frame(value=seq_along(levs), label=levs)))
	}
	if (!is.list(r)) {
		return(NULL)
	}

	# gstat echoes the x and y coordinates as the first two output columns
	if (inherits(model, "gstat") && is.data.frame(r) && ncol(r) >= 2 && ncol(d) >= 2) {
		k <- min(nrow(d), nrow(r), 5)
		if (k > 0) {
			xy <- as.matrix(d[seq_len(k), 1:2])
			if (isTRUE(all(xy == as.matrix(r[seq_len(k), 1:2])))) {
				r <- r[, -(1:2), drop=FALSE]   # x, y
			}
		}
	}

	# lapply (not sapply) so the result is always a list, one element per column
	out <- lapply(r, levels)
	for (j in seq_along(out)) {
		if (!is.null(out[[j]])) {
			out[[j]] <- data.frame(value=seq_along(out[[j]]), label=out[[j]])
		}
	}
	out
}

setMethod("predict", signature(object="SpatRaster"),
	function(object, model, fun=predict, ..., factors=NULL, const=NULL, na.rm=FALSE, index=NULL, cores=1, cpkgs=NULL, filename="", overwrite=FALSE, wopt=list()) {

		# layer names are used as predictor variable names, so they must be unique
		nms <- names(object)
		if (length(unique(nms)) != length(nms)) {
			tab <- table(nms)
			error("predict", "duplicate names in SpatRaster: ", paste(names(tab)[tab > 1], collapse=", "))
		}

		# 'factors' is not used: factor levels should come with the SpatRaster

		nl <- 1
		nc <- ncol(object)
		nr <- nrow(object)
		readStart(object)
		on.exit(readStop(object))

		# Predict for a small sample of cells to find the number, names and factor
		# levels of the output layers. Start with one row near the middle; for
		# single-column rasters read up to 20 rows, without going past the last row.
		testrow <- round(0.51*nr)
		rnr <- 1
		if (nc == 1) rnr <- min(nr - testrow + 1, 20)
		d <- readValues(object, testrow, rnr, 1, nc, TRUE, TRUE)
		cn <- NULL
		levs <- NULL
		if (!is.null(index)) {
			nl <- length(index)
		} else {
			allna <- FALSE
			if (na.rm) {
				# if the sample is all NA, try other parts of the raster
				allna <- all(is.na(d))
				if (allna) {
					testrow <- ceiling(testrow - 0.25*nr)
					d <- readValues(object, testrow, rnr, 1, nc, TRUE, TRUE)
					allna <- all(is.na(d))
				}
				if (allna) {
					testrow <- floor(testrow + 0.5*nr)
					if ((testrow + rnr - 1) > nr) rnr <- nr - testrow + 1
					d <- readValues(object, testrow, rnr, 1, nc, TRUE, TRUE)
					allna <- all(is.na(d))
				}
				if (allna && (ncell(object) < 1000)) {
					d <- readValues(object, 1, nr, 1, nc, TRUE, TRUE)
					allna <- all(is.na(d))
				}
				if (allna) {
					d <- spatSample(object, min(1000, ncell(object)), "regular")
					allna <- all(is.na(d))
				}
			}
			if (!allna) {
				r <- .runModel(model, fun, d, nl, const, na.rm, index, ...)
				if (ncell(object) > 1) {
					nl <- ncol(r)
					cn <- colnames(r)
				} else {
					nl <- length(r)
				}
				levs <- .getFactors(model, fun, d, nl, const, na.rm, index, ...)
			} else {
				warn("predict", "Cannot determine the number of output variables. Assuming 1. Use argument 'index' to set it manually")
			}
		}
		out <- rast(object, nlyrs=nl)
		levels(out) <- levs
		if (length(cn) == nl) names(out) <- make.names(cn, TRUE)

		cls <- NULL
		if (cores > 1) {
			cls <- parallel::makeCluster(cores)
			on.exit(parallel::stopCluster(cls), add=TRUE)
			parallel::clusterExport(cls, c("model", "fun"), environment())
			if (!is.null(cpkgs)) {
				parallel::clusterExport(cls, "cpkgs", environment())
				parallel::clusterCall(cls, function() for (i in seq_along(cpkgs)) {library(cpkgs[i], character.only=TRUE) })
			}
			# Export the named '...' arguments to the workers as well. Unnamed ones
			# cannot be exported by name; all of them are still passed on to 'fun'
			# as arguments.
			dots <- list(...)
			dnms <- names(dots)
			if (!is.null(dnms)) {
				dots <- dots[nzchar(dnms)]
				if (length(dots) > 0) {
					parallel::clusterExport(cls, names(dots), list2env(dots))
				}
			}
		}
		b <- writeStart(out, filename, overwrite, wopt=wopt, n=max(nlyr(out), nlyr(object))*4)
		for (i in seq_len(b$n)) {
			d <- readValues(object, b$row[i], b$nrows[i], 1, nc, TRUE, TRUE)
			r <- .runModel(model, fun, d, nl, const, na.rm, index, cores=cores, cls=cls, ...)
			writeValues(out, r, b$row[i], b$nrows[i])
		}
		writeStop(out)
		out
	}
)

