/*
 * Copyright (c) 2001-2003 The Trustees of Indiana University.  
 *                         All rights reserved.
 * Copyright (c) 1998-2001 University of Notre Dame. 
 *                         All rights reserved.
 * Copyright (c) 1994-1998 The Ohio State University.  
 *                         All rights reserved.
 * 
 * This file is part of the LAM/MPI software package.  For license
 * information, see the LICENSE file in the top level directory of the
 * LAM/MPI source distribution.
 * 
 * $HEADER$
 *
 * $Id: ssi_coll_lam_basic_reduce_scatter.c,v 1.2 2003/05/28 00:16:28 jsquyres Exp $
 *
 *	Function:	- Basic collective routines
 */

#include <lam_config.h>
#if LAM_WANT_PROFILE
#define LAM_PROFILELIB 1
#endif
#include <lam-ssi-coll-lam-basic-config.h>

#include <errno.h>
#include <limits.h>
#include <stdlib.h>

#include <lam-ssi-coll.h>
#include <lam-ssi-coll-lam-basic.h>
#include <mpi.h>
#include <blktype.h>
#include <mpisys.h>


/*
 *	reduce_scatter
 *
 *	Function:	- reduce then scatter
 *			- every rank contributes sum(rcounts) elements to an
 *			  MPI_Reduce onto rank 0; rank 0 then MPI_Scatterv's
 *			  rcounts[i] elements of the result to rank i
 *	Accepts:	- same as MPI_Reduce_scatter()
 *	Returns:	- MPI_SUCCESS or error code
 *			- EINVAL if any rcounts[i] is negative or the total
 *			  element count does not fit in an int
 *			- ENOMEM if rank 0 cannot allocate the displacements
 *			- otherwise the error from lam_dtbuffer(), MPI_Reduce()
 *			  or MPI_Scatterv()
 */
int
lam_ssi_coll_lam_basic_reduce_scatter(void *sbuf, void *rbuf, int *rcounts,
				      MPI_Datatype dtype, MPI_Op op,
				      MPI_Comm comm)
{
  int i;
  int err;
  int rank;
  int size;
  int count;
  int *disps = NULL;
  char *buffer = NULL;
  char *origin = NULL;

  MPI_Comm_size(comm, &size);
  MPI_Comm_rank(comm, &rank);

  /* Total element count of the reduction.  Computed on every rank, since
     every rank passes it to MPI_Reduce below.  Reject negative counts and
     totals that would overflow an int (signed overflow is undefined). */

  for (i = 0, count = 0; i < size; ++i) {
    if (rcounts[i] < 0 || rcounts[i] > INT_MAX - count) {
      return EINVAL;
    }
    count += rcounts[i];
  }

  /* Only the root (rank 0) needs the reduction buffer and the scatterv
     displacements; on other ranks origin and disps stay NULL. */

  if (rank == 0) {
    disps = (int *) malloc((size_t) size * sizeof(*disps));
    if (disps == NULL) {
      /* Don't rely on malloc setting errno; ISO C doesn't require it. */
      return ENOMEM;
    }

    err = lam_dtbuffer(dtype, count, &buffer, &origin);
    if (err != MPI_SUCCESS) {
      free(disps);
      return err;
    }

    /* Rank i's block starts immediately after rank i-1's block.  Cannot
       overflow: every partial sum is bounded by count. */
    disps[0] = 0;
    for (i = 0; i < (size - 1); ++i) {
      disps[i + 1] = disps[i] + rcounts[i];
    }
  }

  /* reduction */

  err = MPI_Reduce(sbuf, origin, count, dtype, op, 0, comm);
  if (err != MPI_SUCCESS) {
    goto cleanup;
  }

  /* scatter */

  err = MPI_Scatterv(origin, rcounts, disps, dtype,
		     rbuf, rcounts[rank], dtype, 0, comm);

cleanup:
  /* free(NULL) is a no-op, so non-root ranks need no special casing. */
  free(disps);
  free(buffer);
  return err;
}
