# file nnet/knn.q copyright (C) 1994-9 W. N. Venables and B. D. Ripley
#
# 1-nearest-neighbour classification.
# Each row of 'test' is given the class of its nearest row of 'train';
# the neighbour search itself is done by the compiled routine VR_knn1.
#   train  matrix (or object coercible by as.matrix) of training cases
#   test   matrix of test cases; a plain vector is taken as one test case
#   cl     classes of the training cases (coerced with as.factor)
# Returns a factor of length nrow(test) with the levels of as.factor(cl).
knn1 <- function(train, test, cl)
{
	train <- as.matrix(train)
	# a bare vector is treated as a single test case (one row)
	if(is.null(dim(test))) dim(test) <- c(1, length(test))
	test <- as.matrix(test)
	p <- ncol(train)
	ntr <- nrow(train)
	if(length(cl) != ntr) stop("train and class have different lengths")
	nte <- nrow(test)
	if(ncol(test) != p) stop("Dims of test and train differ")
	# NA coordinates or classes would otherwise be handed straight to the
	# C code (an NA class becomes NA_integer_ in the class-code vector)
	if(any(is.na(train)) || any(is.na(test)) || any(is.na(cl)))
		stop("no missing values are allowed")
	clf <- as.factor(cl)
	nc <- max(unclass(clf))
	res <- .C("VR_knn1",
		as.integer(ntr),
		as.integer(nte),
		as.integer(p),
		as.double(train),
		as.integer(unclass(clf)),
		as.double(test),
		res = integer(nte),
		integer(nc+1),   # work space of length nc+1 for VR_knn1
		as.integer(nc),
		d = double(nte)
		)$res
	# map the integer class codes returned by C back to the labels of cl
	factor(res, levels=seq(along=levels(clf)), labels=levels(clf))
}

# k-nearest-neighbour classification of each row of 'test' against the
# rows of 'train', whose classes are 'cl'; the search and voting are done
# by the compiled routine VR_knn.
#   train    matrix (or object coercible by as.matrix) of training cases
#   test     matrix of test cases; a plain vector is taken as one test case
#   cl       classes of the training cases (coerced with as.factor)
#   k        number of neighbours; reduced (with a warning) to nrow(train)
#            when larger, and must be at least 1
#   l        passed unchanged to VR_knn (presumably the minimum vote for a
#            definite decision -- confirm against the C source)
#   prob     if TRUE, the 'pr' values returned by VR_knn are attached as
#            attribute "prob"
#   use.all  passed unchanged to VR_knn
# Returns a factor of length nrow(test) with the levels of as.factor(cl);
# any class code outside 1..nlevels returned by VR_knn becomes NA.
knn <- function(train, test, cl, k=1, l=0, prob=FALSE, use.all=TRUE)
{
	train <- as.matrix(train)
	if(is.null(dim(test))) dim(test) <- c(1, length(test))
	test <- as.matrix(test)
	p <- ncol(train)
	ntr <- nrow(train)
	if(length(cl) != ntr) stop("train and class have different lengths")
	if(ntr < k) {
	   warning(paste("k =",k,"exceeds number",ntr,"of patterns"))
	   k <- ntr
	}
	if (k < 1) stop(paste("k =",k,"must be at least 1"))
	nte <- nrow(test)
	if(ncol(test) != p) stop("Dims of test and train differ")
	# NA coordinates or classes would otherwise be handed straight to the
	# C code (an NA class becomes NA_integer_ in the class-code vector)
	if(any(is.na(train)) || any(is.na(test)) || any(is.na(cl)))
		stop("no missing values are allowed")
	clf <- as.factor(cl)
	nc <- max(unclass(clf))
	Z <- .C("VR_knn",
		as.integer(k),
		as.integer(l),
		as.integer(ntr),
		as.integer(nte),
		as.integer(p),
		as.double(train),
		as.integer(unclass(clf)),
		as.double(test),
		res = integer(nte),
		pr = double(nte),
		integer(nc+1),      # work space of length nc+1 for VR_knn
		as.integer(nc),
		as.integer(FALSE),  # same slot is TRUE in knn.cv
		as.integer(use.all)
		)
	res <- factor(Z$res, levels=seq(along=levels(clf)),labels=levels(clf))
	if(prob) attr(res, "prob") <- Z$pr
	res
}

# Cross-validated k-nearest-neighbour classification of the training set:
# VR_knn is called with 'train' as both training and test set and its
# cross-validation flag set (presumably leave-one-out, given that k is
# capped at nrow(train) - 1 -- confirm against the C source).
#   train    matrix (or object coercible by as.matrix) of cases
#   cl       classes of the cases (coerced with as.factor)
#   k        number of neighbours; reduced (with a warning) to
#            nrow(train) - 1 when larger, and must be at least 1
#   l, use.all  passed unchanged to VR_knn
#   prob     if TRUE, the 'pr' values returned by VR_knn are attached as
#            attribute "prob"
# Returns a factor of length nrow(train) with the levels of as.factor(cl).
knn.cv <- function(train, cl, k=1, l=0, prob=FALSE, use.all=TRUE)
{
	train <- as.matrix(train)
	p <- ncol(train)
	ntr <- nrow(train)
	# knn() and knn1() check this too; without it a short 'cl' would be
	# handed to C alongside ntr rows
	if(length(cl) != ntr) stop("train and class have different lengths")
	if(ntr-1 < k) {
	   warning(paste("k =",k,"exceeds number",ntr-1,"of patterns"))
	   k <- ntr - 1
	}
	if (k < 1) stop(paste("k =",k,"must be at least 1"))
	# NA coordinates or classes would otherwise be handed straight to C
	if(any(is.na(train)) || any(is.na(cl)))
		stop("no missing values are allowed")
	clf <- as.factor(cl)
	nc <- max(unclass(clf))
	Z <- .C("VR_knn",
		as.integer(k),
		as.integer(l),
		as.integer(ntr),
		as.integer(ntr),
		as.integer(p),
		as.double(train),
		as.integer(unclass(clf)),
		as.double(train),
		res = integer(ntr),
		pr = double(ntr),
		integer(nc+1),      # work space of length nc+1 for VR_knn
		as.integer(nc),
		as.integer(TRUE),   # cross-validation flag (FALSE in knn)
		as.integer(use.all)
		)
	res <- factor(Z$res, levels=seq(along=levels(clf)),labels=levels(clf))
	if(prob) attr(res, "prob") <- Z$pr
	res
}
# file nnet/lvq.q copyright (C) 1994-9 W. N. Venables and B. D. Ripley
#
# Build an initial LVQ codebook: a class-stratified random sample of those
# training cases that knn.cv() (with k neighbours) classifies correctly.
#   x      matrix (or object coercible by as.matrix) of cases in rows
#   cl     classes of the rows of x
#   size   total codebook size; default min(round(0.4*np*(np-1+p/2)), n)
#          for np classes, p variables and n cases
#   prior  class proportions for the sample; must be non-negative, sum to
#          1 (to 5 d.p.) and have one entry per class; defaults to the
#          observed class proportions
#   k      number of neighbours used by knn.cv() to screen the cases
# Returns list(x = selected rows of x, cl = their classes).
lvqinit <- function(x, cl, size, prior, k=5) {
	x <- as.matrix(x)
	nobs <- nrow(x)
	nvar <- ncol(x)
	if(length(cl) != nobs) stop("x and cl have different lengths")
	grp <- as.factor(cl)
	props <- tapply(rep(1, length(grp)), grp, sum) / nobs
	ncls <- length(props)
	# allow for supplied prior
	if(missing(prior)) {
		prior <- props
	} else if(any(prior < 0) || round(sum(prior), 5) != 1) {
		stop("invalid prior")
	}
	if(length(prior) != ncls) stop("prior is of incorrect length")
	if(missing(size))
		size <- min(round(0.4 * ncls * (ncls - 1 + nvar/2), 0), nobs)
	# only cases that cross-validated k-NN gets right are candidates
	correct <- knn.cv(x, cl, k) == cl
	grp_code <- unclass(grp)
	per_class <- lapply(seq_len(ncls), function(j) {
		members <- seq_along(grp)[grp_code == j & correct]
		# guard: sample() applied to a single number m draws from 1:m
		if(length(members) > 1)
			members <- sample(members,
			                  min(length(members), round(size * prior[j])))
		members
	})
	chosen <- unlist(per_class)
	list(x = x[chosen, , drop=FALSE], cl = cl[chosen])
}

# Train an LVQ codebook with the compiled routine VR_olvq (presumably
# Kohonen's optimized-learning-rate LVQ1).
#   x, cl   training cases (rows of x) and their classes
#   codebk  list(x = prototype matrix, cl = prototype classes), e.g. as
#           returned by lvqinit()
#   niter   number of presentations; cases are drawn at random with
#           replacement
#   alpha   learning-rate parameter passed to VR_olvq
# Returns list(x, cl): the updated prototypes (dimnames of codebk$x kept)
# and the unchanged prototype classes.
# NOTE(review): cl and codebk$cl are passed to C via as.integer(), so they
# must be factors or integer codes on the same scale (a character cl would
# become NA) -- confirm callers respect this.
olvq1 <- function(x, cl, codebk, niter = 40*nrow(codebk$x), alpha = 0.3) {
	x <- as.matrix(x)
	n <- nrow(x)
	p <- ncol(x)
	nc <- dim(codebk$x)[1]
	if(length(cl) != n) stop("x and cl have different lengths")
	# VR_olvq is told only nc and p, and the result is reshaped to nc x p
	# below, so a codebook of any other shape would be misread by C
	if(ncol(codebk$x) != p)
		stop("x and codebk$x have different numbers of columns")
	if(length(codebk$cl) != nc)
		stop("codebk$cl does not match the number of rows of codebk$x")
	if(any(is.na(x)) || any(is.na(cl)) ||
	   any(is.na(codebk$x)) || any(is.na(codebk$cl)))
		stop("no missing values are allowed")
	iters <- sample(n, niter, TRUE)
	z <- .C("VR_olvq",
		as.double(alpha),
		as.integer(n),
		as.integer(p),
		as.double(x),
		as.integer(unclass(cl)),
		as.integer(nc),
		xc = as.double(codebk$x),
		as.integer(codebk$cl),
		as.integer(niter),
		as.integer(iters-1)   # 0-based case indices for C
		)
	xc <- matrix(z$xc,nc,p)
	dimnames(xc) <- dimnames(codebk$x)
	list(x = xc, cl = codebk$cl)
}

# Train an LVQ codebook with the compiled routine VR_lvq1 (LVQ1).
#   x, cl   training cases (rows of x) and their classes
#   codebk  list(x = prototype matrix, cl = prototype classes), e.g. as
#           returned by lvqinit()
#   niter   number of presentations; cases are drawn at random with
#           replacement
#   alpha   learning-rate parameter passed to VR_lvq1
# Returns list(x, cl): the updated prototypes (dimnames of codebk$x kept)
# and the unchanged prototype classes.
# NOTE(review): cl and codebk$cl are passed to C via as.integer(), so they
# must be factors or integer codes on the same scale (a character cl would
# become NA) -- confirm callers respect this.
lvq1 <- function(x, cl, codebk, niter = 100*nrow(codebk$x), alpha = 0.03) {
	x <- as.matrix(x)
	n <- nrow(x)
	p <- ncol(x)
	nc <- dim(codebk$x)[1]
	if(length(cl) != n) stop("x and cl have different lengths")
	# VR_lvq1 is told only nc and p, and the result is reshaped to nc x p
	# below, so a codebook of any other shape would be misread by C
	if(ncol(codebk$x) != p)
		stop("x and codebk$x have different numbers of columns")
	if(length(codebk$cl) != nc)
		stop("codebk$cl does not match the number of rows of codebk$x")
	if(any(is.na(x)) || any(is.na(cl)) ||
	   any(is.na(codebk$x)) || any(is.na(codebk$cl)))
		stop("no missing values are allowed")
	iters <- sample(n, niter, TRUE)
	z <- .C("VR_lvq1",
		as.double(alpha),
		as.integer(n),
		as.integer(p),
		as.double(x),
		as.integer(unclass(cl)),
		as.integer(nc),
		xc = as.double(codebk$x),
		as.integer(codebk$cl),
		as.integer(niter),
		as.integer(iters-1)   # 0-based case indices for C
		)
	xc <- matrix(z$xc,nc,p)
	dimnames(xc) <- dimnames(codebk$x)
	list(x = xc, cl = codebk$cl)
}

# Train an LVQ codebook with the compiled routine VR_lvq2 (LVQ2.1).
#   x, cl   training cases (rows of x) and their classes
#   codebk  list(x = prototype matrix, cl = prototype classes), e.g. as
#           returned by lvqinit()
#   niter   number of presentations; cases are drawn at random with
#           replacement
#   alpha   learning-rate parameter passed to VR_lvq2
#   win     window parameter passed to VR_lvq2
# Returns list(x, cl): the updated prototypes (dimnames of codebk$x kept)
# and the unchanged prototype classes.
# NOTE(review): cl and codebk$cl are passed to C via as.integer(), so they
# must be factors or integer codes on the same scale (a character cl would
# become NA) -- confirm callers respect this.
lvq2 <- function(x, cl, codebk, niter = 100*nrow(codebk$x), alpha = 0.03, win = 0.3) {
	x <- as.matrix(x)
	n <- nrow(x)
	p <- ncol(x)
	nc <- dim(codebk$x)[1]
	if(length(cl) != n) stop("x and cl have different lengths")
	# VR_lvq2 is told only nc and p, and the result is reshaped to nc x p
	# below, so a codebook of any other shape would be misread by C
	if(ncol(codebk$x) != p)
		stop("x and codebk$x have different numbers of columns")
	if(length(codebk$cl) != nc)
		stop("codebk$cl does not match the number of rows of codebk$x")
	if(any(is.na(x)) || any(is.na(cl)) ||
	   any(is.na(codebk$x)) || any(is.na(codebk$cl)))
		stop("no missing values are allowed")
	iters <- sample(n, niter, TRUE)
	z <- .C("VR_lvq2",
		as.double(alpha),
		as.double(win),
		as.integer(n),
		as.integer(p),
		as.double(x),
		as.integer(unclass(cl)),
		as.integer(nc),
		xc = as.double(codebk$x),
		as.integer(codebk$cl),
		as.integer(niter),
		as.integer(iters-1)   # 0-based case indices for C
		)
	xc <- matrix(z$xc,nc,p)
	dimnames(xc) <- dimnames(codebk$x)
	list(x = xc, cl = codebk$cl)
}

# Train an LVQ codebook with the compiled routine VR_lvq3 (LVQ3).
#   x, cl    training cases (rows of x) and their classes
#   codebk   list(x = prototype matrix, cl = prototype classes), e.g. as
#            returned by lvqinit()
#   niter    number of presentations; cases are drawn at random with
#            replacement
#   alpha    learning-rate parameter passed to VR_lvq3
#   win      window parameter passed to VR_lvq3
#   epsilon  passed to VR_lvq3
# Returns list(x, cl): the updated prototypes (dimnames of codebk$x kept)
# and the unchanged prototype classes.
# NOTE(review): cl and codebk$cl are passed to C via as.integer(), so they
# must be factors or integer codes on the same scale (a character cl would
# become NA) -- confirm callers respect this.
lvq3 <- function(x, cl, codebk, niter = 100*nrow(codebk$x), alpha = 0.03, win = 0.3, epsilon = 0.1) {
	x <- as.matrix(x)
	n <- nrow(x)
	p <- ncol(x)
	nc <- dim(codebk$x)[1]
	if(length(cl) != n) stop("x and cl have different lengths")
	# VR_lvq3 is told only nc and p, and the result is reshaped to nc x p
	# below, so a codebook of any other shape would be misread by C
	if(ncol(codebk$x) != p)
		stop("x and codebk$x have different numbers of columns")
	if(length(codebk$cl) != nc)
		stop("codebk$cl does not match the number of rows of codebk$x")
	if(any(is.na(x)) || any(is.na(cl)) ||
	   any(is.na(codebk$x)) || any(is.na(codebk$cl)))
		stop("no missing values are allowed")
	iters <- sample(n, niter, TRUE)
	z <- .C("VR_lvq3",
		as.double(alpha),
		as.double(win),
		as.double(epsilon),
		as.integer(n),
		as.integer(p),
		as.double(x),
		as.integer(unclass(cl)),
		as.integer(nc),
		xc = as.double(codebk$x),
		as.integer(codebk$cl),
		as.integer(niter),
		as.integer(iters-1)   # 0-based case indices for C
		)
	xc <- matrix(z$xc,nc,p)
	dimnames(xc) <- dimnames(codebk$x)
	list(x = xc, cl = codebk$cl)
}

# Classify 'test' with a trained LVQ codebook: each test case receives the
# class of its nearest prototype, found by knn1().
#   codebk  list(x = prototype matrix, cl = prototype classes)
#   test    matrix of test cases (a plain vector is one case)
# Returns the factor produced by knn1().
lvqtest <- function(codebk, test)
{
	prototypes <- codebk$x
	labels <- codebk$cl
	knn1(prototypes, test, labels)
}
# file nnet/multiedit.q copyright (C) 1994-9 W. N. Venables and B. D. Ripley
#
# Multiedit: iteratively discard training cases that are misclassified.
# On each pass the cases are split at random into V groups; group
# (i %% V) + 1 is classified by knn() using group i as training set, and
# every misclassified case is dropped.  Editing stops once I consecutive
# passes remove nothing, or (with a warning) when fewer than 5*V cases remain.
#   x      matrix / data frame of cases in rows
#   class  classes of the rows of x
#   k      number of neighbours passed to knn()
#   trace  if TRUE, print the pass number and retained size after each pass
# Returns the indices (into the original x) of the retained cases.
multiedit <- function(x, class, k=1, V=3, I=5, trace=TRUE)
{
     n1 <- length(class)
     # codes() is defunct in current R; a factor keeps all levels when
     # subset, so knn()'s factor output compares cleanly with class[test]
     class <- as.factor(class)
     # seq_len() gives integer(0) for empty input, where 1:0 gives c(1, 0)
     index <- seq_len(n1)
     pass <- lpass <- 0
     repeat{
         if(n1 < 5*V) {
             warning("retained set is now too small to proceed")
             break
         }
         pass <- pass + 1
         sub <- sample(V, length(class), replace=TRUE)
         keep <- logical(length(class))
         for (i in seq_len(V)) {
             train <- sub == i
             test <- sub == (1 + i %% V)
             keep[test] <- (knn(x[train, , drop=FALSE], x[test, , drop=FALSE],
                                class[train], k) == class[test])
         }
         x <- x[keep, , drop=FALSE]; class <- class[keep]; index <- index[keep]
         n2 <- length(class)
         if(n2 < n1) lpass <- pass   # remember the last pass that removed cases
         if(lpass <= pass - I) break
         n1 <- n2
         if(trace) cat(paste("pass ", pass," size ", n2, "\n"))
     }
     index
}

# Hart's condensing: grow a store of prototypes until 1-NN classification
# using the store (via knn1()) gets every case outside the store right.
# Each round one misclassified case, picked at random, joins the store.
#   train  matrix of cases in rows
#   class  classes of the rows of train
#   store  initial store indices (default: one random case)
#   trace  if TRUE, print the current store indices every round
# Returns the indices of the cases in the final store.
condense <- function(train, class, store=sample(seq(n), 1), trace=TRUE)
{
     # the default for 'store' refers to 'n', so 'n' must be bound before
     # 'store' is first forced below
     n <- length(class)
     ids <- seq(n)
     outside <- rep(TRUE, n)      # TRUE = not (yet) in the store
     outside[store] <- FALSE
     repeat {
        if(trace) print(ids[!outside])
        if(sum(outside) == 0) break
        pred <- knn1(train[!outside, , drop = FALSE],
                     train[outside, , drop = FALSE],
                     class[!outside])
        wrong <- pred != class[outside]
        if(sum(wrong) == 0) break
        newcomer <- ids[outside][wrong]
        # sample() on a single number m would draw from 1:m, hence the guard
        if(length(newcomer) > 1) newcomer <- sample(newcomer, 1)
        outside[newcomer] <- FALSE
     }
     ids[!outside]
}

# Reduce a prototype set for 1-NN classification.  Starting from the
# prototypes 'ind' (e.g. the output of condense()), visit them in random
# order. Drop each one whose removal still lets the remaining prototypes
# classify, via knn1(), every case outside the reduced set correctly,
# including the candidate itself.
#   train  matrix of cases in rows
#   ind    indices of the initial prototype set
#   class  classes of the rows of train
# Returns the sorted indices of the prototypes that are kept.
reduce.nn <- function(train, ind, class)
{
     n <- length(class)
     # logical membership instead of negative indexing: x[-integer(0)] is
     # empty, which lost every prototype when nothing could be dropped
     dropped <- !(seq_len(n) %in% ind)   # TRUE = outside the reduced set
# this must be done iteratively, not simultaneously
     # index into ind: sample(ind) would permute 1:ind when length(ind) == 1
     for(i in ind[sample.int(length(ind))]) {
	 trial <- dropped
	 trial[i] <- TRUE
	 # never remove the last prototype: knn1() needs a training set
	 if(all(trial)) next
	 res <- knn1(train[!trial, , drop=FALSE], train[trial, , drop=FALSE],
	             class[!trial])
	 if(all(res == class[trial])) dropped <- trial
     }
     which(!dropped)
}

# Package load hook: load the compiled code (shared library "class") that
# the .C() calls in this package rely on.
# NOTE(review): .First.lib is only honoured for packages without a
# namespace; current R expects .onLoad() or useDynLib() in NAMESPACE --
# confirm the target R version.
.First.lib <- function(lib, pkg)
{
	library.dynam(chname = "class", package = pkg, lib.loc = lib)
}
