## Copyright (C) 1997  Friedrich Leisch
## 
## This program is free software; you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 2, or (at your option)
## any later version.
## 
## This program is distributed in the hope that it will be useful, but
## WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
## General Public License for more details. 
## 
## You should have received a copy of the GNU General Public License
## along with this file.  If not, write to the Free Software Foundation,
## 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

## Author:  FL (Friedrich.Leisch@ci.tuwien.ac.at)

library("e1071")

## Read an SNNS result file into an "nn.pattern" object.
##
## The external converter `res2r` (located via system.file("cmd", "res2r"))
## rewrites `file` into a temporary file, which is parsed by
## nn.read.tmpfile().  The temporary file is always removed, even when the
## conversion or the parsing fails.
##
## Arguments:
##   file  path of the SNNS result file.
## Value:
##   the object returned by nn.read.tmpfile().
## NOTE(review): in current R, system.file() without `package=` searches the
## base package only; confirm where `res2r` is installed.
snns.read.result <- function(file)
{
  res2r <- system.file("cmd", "res2r")
  if (identical(res2r, ""))
    stop("cannot locate the 'res2r' converter")

  tmpf <- tempfile()
  on.exit(unlink(tmpf))

  status <- system(paste(shQuote(res2r), shQuote(file), ">", shQuote(tmpf)))
  if (status != 0)
    stop("'res2r' failed with exit status ", status)

  nn.read.tmpfile(tmpf)
}

## Read an SNNS pattern file into an "nn.pattern" object.
##
## The external converter `pat2r` (located via system.file("cmd", "pat2r"))
## rewrites `file` into a temporary file, which is parsed by
## nn.read.tmpfile().  The temporary file is always removed, even when the
## conversion or the parsing fails.
##
## Arguments:
##   file  path of the SNNS pattern file.
## Value:
##   the object returned by nn.read.tmpfile().
## NOTE(review): in current R, system.file() without `package=` searches the
## base package only; confirm where `pat2r` is installed.
snns.read.nn.pattern <- function(file)
{
  pat2r <- system.file("cmd", "pat2r")
  if (identical(pat2r, ""))
    stop("cannot locate the 'pat2r' converter")

  tmpf <- tempfile()
  on.exit(unlink(tmpf))

  status <- system(paste(shQuote(pat2r), shQuote(file), ">", shQuote(tmpf)))
  if (status != 0)
    stop("'pat2r' failed with exit status ", status)

  nn.read.tmpfile(tmpf)
}


## Parse the intermediate file written by the res2r / pat2r converters.
##
## The first line holds column counts n: n[1] input columns, n[2] target
## columns and n[3] output columns.  The remaining lines hold the numbers
## of all patterns, sum(n) values per pattern, read row by row.
##
## Arguments:
##   file  path of the intermediate file.
## Value:
##   an "nn.pattern" object.  Parts with a zero column count are NULL; the
##   others are always matrices, even for a single pattern or column.
nn.read.tmpfile <- function(file)
{
  n <- scan(file=file, nlines=1, quiet=TRUE)
  if (length(n) < 3)
    stop("header of '", file, "' must contain 3 column counts")
  if (sum(n) <= 0)
    stop("header of '", file, "' declares no columns")

  x <- scan(file=file, skip=1, quiet=TRUE)
  x <- matrix(x, ncol=sum(n), byrow=TRUE)

  ## columns offset+1 .. offset+width, kept as a matrix (drop = FALSE)
  ## so that a one-pattern file does not collapse into a column vector
  take <- function(offset, width)
    if (width > 0) x[, offset + seq_len(width), drop = FALSE]

  nn.pattern(input  = take(0, n[1]),
             target = take(n[1], n[2]),
             output = take(n[1] + n[2], n[3]))
}

## S3 generic: write `x` in SNNS format, dispatching on class(x).
snns.write <- function(x, ...) {
  UseMethod("snns.write")
}

## Write an "nn.pattern" object to `file` as an SNNS pattern file.
##
## The header is produced by the external command `mkhead`, called with
## the number of patterns, input units and output units.
## NOTE(review): `mkhead` is assumed to be on the search PATH; confirm.
## Each pattern is then appended as a "# Input pattern r:" line plus the
## input row.  When targets are written, each pattern also gets a
## "# Output pattern r:" line plus the target row.
##
## Arguments:
##   pattern  list with an `input` component and optionally `target`.
##   file     path of the pattern file to create (overwritten by mkhead).
##   wtarget  write the targets?  Forced to FALSE when there are none.
##   ...      ignored; present for consistency with the generic.
## Value:
##   invisible NULL; called for its side effect.
snns.write.nn.pattern <- function(pattern, file, wtarget=TRUE, ...)
{
  if (is.null(pattern$input))
    stop("Empty matrix of inputs")
  input <- as.matrix(pattern$input)
  ninput <- nrow(input)
  dinput <- ncol(input)

  if (is.null(pattern$target))
    wtarget <- FALSE

  if (wtarget) {
    target <- as.matrix(pattern$target)
    if (ninput != nrow(target))
      stop("Input and target must have the same number of rows")
    dtarget <- ncol(target)
  } else {
    ## the header must not announce output units that are never written
    dtarget <- 0
  }

  status <- system(paste("mkhead", ninput, dinput, dtarget, ">",
                         shQuote(file)))
  if (status != 0)
    stop("'mkhead' failed with exit status ", status)

  for (r in seq_len(ninput)) {
    write(paste("# Input pattern ", r, ":", sep=""),
          file=file, append=TRUE)
    write(input[r, ], file=file, append=TRUE)
    if (wtarget) {
      write(paste("# Output pattern ", r, ":", sep=""),
            file=file, append=TRUE)
      write(target[r, ], file=file, append=TRUE)
    }
  }
  invisible(NULL)
}


## Construct an "nn.pattern" object.
##
## Arguments:
##   input, target, output, code
##       numeric data, coerced with as.matrix().  A NULL argument is kept
##       as NULL (as.matrix(NULL) is an error), so callers can test for an
##       absent part with is.null().
##   target.labels, output.labels
##       class labels, stored unchanged.
## Value:
##   a list of class "nn.pattern" with components input, target, output,
##   code, target.labels and output.labels.
nn.pattern <- function(input, target=NULL, output=NULL, code=NULL,
                       target.labels=NULL, output.labels=NULL)
{
  mat <- function(m) if (is.null(m)) NULL else as.matrix(m)

  z <- list(input  = mat(input),
            target = mat(target),
            output = mat(output),
            code   = mat(code),
            target.labels = target.labels,
            output.labels = output.labels)
  class(z) <- "nn.pattern"
  z
}

## S3 generic: assign a class to every row of `x`, dispatching on class(x).
find.classes <- function(x, ...) {
  UseMethod("find.classes")
}

## Assign a class to every row of a score matrix.
##
## Arguments:
##   x       matrix (or something as.matrix() accepts) with one row per
##           pattern.
##   byname  return class names instead of indices.  Names are the column
##           names of x, or the row names of `code` when `code` is given.
##   method  "WTA": the column holding the row maximum wins, provided that
##           maximum is >= `low`.
##           "euclid": the row of `code` nearest in squared Euclidean
##           distance wins.
##   low     threshold for "WTA"; rows below it are left unclassified.
##   high    accepted for interface compatibility; not used.
##   code    matrix of class codes, one row per class (needed for "euclid").
##   ...     ignored; present for consistency with the generic.
## Value:
##   an integer vector of class indices (character names if byname=TRUE).
##   Entries are NA for unclassified rows, including rows with missing
##   scores.
find.classes.default <- function(x, byname=FALSE, method="WTA",
                                 low=0, high=0, code=NULL, ...){
  if (!method %in% c("WTA", "euclid"))
    stop("Unknown method: ", method)

  x <- as.matrix(x)
  nr <- nrow(x)
  w <- rep(NA_integer_, nr)
  labels <- if (is.null(code)) colnames(x) else rownames(code)

  if (method == "WTA") {
    for (k in seq_len(nr)) {
      row <- x[k, ]
      ## isTRUE() treats an NA maximum as "below threshold"
      if (isTRUE(max(row) >= low))
        w[k] <- which.max(row)
    }
  } else {
    if (is.null(code))
      stop("No codes given for method euclid")
    if (ncol(code) != ncol(x))
      stop("code must have as many columns as x")
    codeT <- t(code)
    for (k in seq_len(nr)) {
      ## column j of codeT is code row j, so colSums gives one distance
      ## per class
      d <- colSums((codeT - x[k, ])^2)
      if (!all(is.na(d)))
        w[k] <- which.min(d)
    }
  }

  if (byname) {
    if (is.null(labels))
      stop("byname=TRUE requires column names of x (or row names of code)")
    return(labels[w])
  }
  w
}

## Fill in the target.labels and output.labels of an "nn.pattern".
##
## A part with more than one column is classified row by row with
## find.classes().  The settings byname, method, low and high are all
## forwarded, and p$code is used as the code matrix.  A part with at most
## one column is used as its own labels.  An absent (NULL) part yields
## NULL labels.
##
## Arguments:
##   p                          an "nn.pattern" object.
##   byname, method, low, high  passed on to find.classes().
##   ...                        ignored; present for consistency with the
##                              generic.
## Value:
##   `p` with target.labels and output.labels updated.
find.classes.nn.pattern <- function(p, byname=FALSE, method="WTA",
                                    low=0, high=0, ...){
  p$target.labels <- p$target
  p$output.labels <- p$output
  ## NCOL(NULL) is 1, so absent parts skip classification
  if (NCOL(p$target) > 1) {
    p$target.labels <- find.classes(p$target, byname=byname, method=method,
                                    low=low, high=high, code=p$code)
  }
  if (NCOL(p$output) > 1) {
    p$output.labels <- find.classes(p$output, byname=byname, method=method,
                                    low=low, high=high, code=p$code)
  }
  p
}


## Print a summary of an "nn.pattern" object.
##
## Prints the dimensions of every component.  When both target and output
## are present and `err` is TRUE, it also prints error(p).  When both label
## vectors are present and `classerr` is TRUE, it prints the proportions
## of patterns whose labels agree ("Right"), disagree ("Wrong"), or
## compare as NA ("Unclassified").
##
## Arguments:
##   p         an "nn.pattern" object.
##   err       print the SSE/MSE summary?
##   classerr  print the classification rates?
##   ...       ignored; required for consistency with summary(object, ...).
## Value:
##   invisibly, the named vector of classification rates, or NULL when it
##   was not computed.
summary.nn.pattern <- function(p, err=TRUE, classerr=TRUE, ...){

  cat("Input         :", dim(p$input), "\n")
  cat("Target        :", dim(p$target), "\n")
  cat("Output        :", dim(p$output), "\n")
  cat("Target Labels :", dim(p$target.labels), "\n")
  cat("Output Labels :", dim(p$output.labels), "\n")
  cat("Code          :", dim(p$code), "\n")
  cat("\n")

  if (!is.null(p$target) && !is.null(p$output) && err) {
    print(error(p))
  }

  if (!is.null(p$target.labels) && !is.null(p$output.labels) && classerr) {
    y <- (p$target.labels == p$output.labels)
    ok <- complete.cases(y)
    total <- length(y)

    s <- c(Right        = sum(y[ok]) / total,
           Wrong        = sum(!y[ok]) / total,
           Unclassified = sum(is.na(y)) / total)
    print(s)
    return(invisible(s))
  }
  invisible(NULL)
}

## S3 generic: compute an error summary for `x`, dispatching on class(x).
error <- function(x, ...) {
  UseMethod("error")
}

## Sum and mean squared error between the target and output of a pattern.
##
## Arguments:
##   p    list with numeric `target` and `output` of identical shape
##        (matrices, or plain vectors treated as one column).
##   ...  ignored; required for consistency with the generic error(x, ...).
## Value:
##   c(SSE = <sum of squared differences>,
##     MSE = SSE / <number of patterns, i.e. rows>).
error.nn.pattern <- function (p, ...) {
  sq <- (p$target - p$output)^2
  sse <- sum(sq)
  c(SSE = sse, MSE = sse / NROW(sq))
}

## S3 generic: one-of-n (indicator) encoding of `x`, dispatching on class(x).
one.in.n <- function(x, ...) {
  UseMethod("one.in.n")
}

## One-of-n indicator encoding.
##
## Arguments:
##   x        vector of observations.
##   classes  the class values, one per output column (default: the sorted
##            distinct values of x).
## Value:
##   integer matrix with length(x) rows and length(classes) columns, named
##   after the classes.  Entry [i, j] is 1 when x[i] == classes[j] and 0
##   otherwise.
one.in.n.default <- function(x, classes=sort(unique(x)))
{
  n.obs <- length(x)
  n.cls <- length(classes)

  ## replicate x across columns and classes down rows, then compare
  ## element by element
  hits <- matrix(x, nrow = n.obs, ncol = n.cls) ==
    matrix(classes, nrow = n.obs, ncol = n.cls, byrow = TRUE)

  indicator <- matrix(as.integer(hits), nrow = n.obs, ncol = n.cls)
  colnames(indicator) <- as.character(classes)
  indicator
}

## S3 generic: map the classes of `x` to code vectors, dispatching on class(x).
encode <- function(x, ...) {
  UseMethod("encode")
}

## Replace every observation by the code row of its class.
##
## The classes are the sorted distinct non-NA values of x.  The k-th
## smallest class is mapped to code[k, ].
##
## Arguments:
##   x     vector of observations.
##   code  matrix with one row per class; required.  (The old default
##         `code=code` was a self-referential promise and always failed.)
##   ...   ignored; present for consistency with the generic.
## Value:
##   code[idx, ] with one entry per element of x.  Missing observations
##   give an NA row instead of being silently dropped.
encode.default <- function(x, code=NULL, ...)
{
  if (is.null(code))
    stop("argument 'code' must be supplied")
  classes <- sort(unique(x))
  code[match(x, classes), ]
}

## Generate a d-dimensional XOR benchmark problem.
##
## Draws n points uniformly from [-1, 1]^d.  For each point, coordinate j
## (j = 2..d) contributes 2^(j-2) to the label when its sign (>= 0 or < 0)
## agrees with the sign of coordinate 1.  The label is 1 + the sum of
## these contributions, so labels lie in 1..2^(d-1).
##
## Arguments:
##   n  number of points.
##   d  dimension; must be an integer >= 2.
## Value:
##   an "nn.pattern" with the points as input, a one-of-n target and the
##   integer-valued labels as target.labels.
xor.bench <- function(n, d=2){

  ## validate before consuming random numbers or indexing with d
  if ((d != as.integer(d)) || (d < 2))
    stop("d must be an integer >=2")

  input <- matrix(runif(n*d, -1, 1), ncol=d, nrow=n)

  ## TRUE where coordinate j has the same sign as coordinate 1; the
  ## length-n vector recycles down the columns, i.e. once per row
  agree <- (input[, 2:d, drop = FALSE] >= 0) == (input[, 1] >= 0)
  z <- 1 + as.vector(agree %*% 2^(0:(d-2)))

  nn.pattern(input=input, target.labels=z, target=one.in.n(z))
}


## Standardise `x`: subtract mean(x), then divide by sqrt(var(x)).
normalize <- function(x) {
  centred <- x - mean(x)
  centred / sqrt(var(x))
}
  

## S3 generic: draw a resample of `x`, dispatching on class(x).
resample <- function(x, ...) {
  UseMethod("resample")
}

## Draw `n` patterns with replacement from an "nn.pattern".
##
## Row i of the input is chosen with probability proportional to prob[i].
## The matching target rows and target labels are carried along, and the
## code matrix is copied unchanged.
##
## The old default `rep(1/nrow(p$input))` had length 1, so every draw
## returned row 1.  The default is now uniform over all rows.  Sampling
## uses sample.int(), so results for a given seed differ from the former
## hand-rolled inverse-CDF draw.
##
## Arguments:
##   p     an "nn.pattern" object.
##   n     number of patterns to draw (default: nrow(p$input), a bootstrap
##         sample).
##   prob  non-negative sampling weights, one per candidate row; they need
##         not sum to 1.
##   ...   ignored; present for consistency with the generic.
## Value:
##   a new "nn.pattern" with input, target, target.labels and code.
resample.nn.pattern <- function(p, n = nrow(p$input),
                                prob = rep(1/nrow(p$input), nrow(p$input)),
                                ...)
{
  idx <- sample.int(length(prob), size = n, replace = TRUE, prob = prob)

  x <- p$input[idx, , drop = FALSE]
  y <- if (!is.null(p$target)) p$target[idx, , drop = FALSE]
  z <- if (!is.null(p$target.labels)) p$target.labels[idx]

  nn.pattern(input=x, target=y, target.labels=z, code=p$code)
}




