#!/usr/bin/env python3
#
# This Python script demonstrates the Greedy Groups and Greedy Groups Refined clustering algorithms.
# Both algorithms are designed for affinity matrices that list pairwise ingroup probabilities, i.e. for each pair of units, the probability that the two units are in the same group/cluster.
#
# Tested using Python 3.11, Mac OS 13.6.5
#
# Stilianos Louca
# Copyright 2024
# 
# LICENSE AGREEMENT
# - - - - - - - - -
# All rights reserved.
# Use and redistributions of this code is permitted for commercial and non-commercial purposes,
# under the following conditions:
#
#	* Redistributions must retain the above copyright notice, this list of 
#	  conditions and the following disclaimer in the code itself, as well 
#	  as in documentation and/or other materials provided with the code.
#	* Neither the name of the original author (Stilianos Louca), nor the names 
#	  of its contributors may be used to endorse or promote products derived 
#	  from this code without specific prior written permission.
#	* Proper attribution must be given to the original author, including a 
#     reference to the peer-reviewed publication through which the code was published.
#
# THIS CODE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS "AS IS" AND ANY 
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 
# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS CODE, 
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# - - - - - - - - -


from matplotlib import pyplot
import numpy
try:
	from numpy import NaN
except ImportError:
	# NumPy >= 2.0 removed the NaN alias; fall back to the canonical lowercase name
	from numpy import nan as NaN
# module-wide random number generator with a fixed seed, so that repeated runs of this script give identical results
# (used for the random visiting orders in the clustering functions and for generating the example affinity matrix)
RNG = numpy.random.default_rng(seed=1233)


##################################
# FUNCTION DEFINITIONS

# given a 1D numeric numpy array, determine the index of the largest non-nan value
def get_largest_item(values, check_nan=True):
	if(check_nan):
		values = numpy.asarray(values)
		valids = numpy.where(~numpy.isnan(values))[0]
		if(len(valids)==0): return -1
		largest = valids[numpy.argmax(values[valids])]
	elif(len(values)==0):
		return -1
	else:
		largest = numpy.argmax(values)
	return largest


# adjust the mean of a set of numbers, for the case where some values are omitted
# Xomitted[] should be a sequence (e.g. list or tuple) of values omitted from the set over which the average is requested
def mean_without(N, Xmean, Xomitted):
	"""
	Compute the mean of a set after removing some of its values, knowing only the mean of the full set.

	N        : number of values in the full set.
	Xmean    : mean of all N values.
	Xomitted : sequence (list, tuple or 1D array) of the values to remove; these are assumed to be members of the full set.
	Returns the mean of the remaining N-M values (M = len(Xomitted)), i.e. (N*Xmean - sum(Xomitted))/(N-M),
	or NaN if no values remain (N <= M).
	"""
	M = len(Xomitted)
	if(N<=M):
		# nothing left to average. numpy.nan is used because the numpy.NaN alias was removed in NumPy 2.0.
		return numpy.nan
	if(M==0):
		return Xmean
	return (Xmean - numpy.sum(Xomitted)/N)*(N/(N-M))


# given a list of NG integer sets ("groups"), with each set containing non-negative integers from 0 to Nmembers-1 ("members"), construct a corresponding list of length Nmembers mapping each member to the groups containing it.
# Note that each member could belong to multiple groups, and some groups may not have any members.
# This function is essentially the inverse of group_integers().
def map_members_to_groups(groups, Nmembers=None, unique=False):
	"""
	Invert a group->members mapping into a member->group(s) mapping.

	groups   : sequence of NG groups; groups[g] is a 1D integer array or integer list of the members of group g.
	Nmembers : total number of members. If None, it is inferred as 1 + the largest member found in any group
	           (0 if there are no groups, or all groups are empty).
	unique   : if False, return a list of length Nmembers whose entry m is a list of the indices (in 0,..,NG-1)
	           of the groups containing member m, in order of increasing group index.
	           If True, the caller guarantees that each member belongs to exactly one group, and a 1D integer
	           numpy array of length Nmembers is returned whose entry m is the index of the group containing m
	           (-1 for members that appear in no group; if a member appears in several groups, the last one wins).
	"""
	if(Nmembers is None):
		# skip empty groups, and fall back to 0 members when there is nothing to inspect
		# (previously max() raised ValueError for an empty 'groups', and all-empty groups yielded 1 member)
		Nmembers = 1+max((max(group) for group in groups if len(group)>0), default=-1)
	if(unique):
		member2group = numpy.full(Nmembers,-1)
		for g,group in enumerate(groups):
			member2group[group] = g
		return member2group
	member2groups = [[] for _ in range(Nmembers)]
	for g,group in enumerate(groups):
		for m in group:
			member2groups[m].append(g)
	return member2groups


# given a square affinity matrix listing ingroup-probabilities, cluster units into groups using the Greedy Groups (GG) algorithm
# The algorithm assumes that A is symmetric, all entries are between 0 and 1, and all diagonal entries are 1.
def greedy_groups(A, affinity_threshold=0.5):
	"""
	Single-pass greedy clustering.

	Units are visited in a random order drawn from the module-level RNG. Each unit joins the existing group
	to which its mean affinity is highest, provided that mean is at least affinity_threshold; otherwise the
	unit starts a new group of its own.

	A                  : square matrix (2D numpy array); A[i,j] is the probability that units i and j share a group.
	affinity_threshold : minimum mean affinity required for a unit to join an existing group.
	Returns a list of groups, each a list of unit indices in the order in which they were added.
	"""
	groups = []
	for unit in RNG.permutation(A.shape[0]):
		scores = [numpy.mean(A[unit,members]) for members in groups] # scores[g]: mean affinity of this unit to group g
		best = get_largest_item(scores, check_nan=False) # -1 while there are no groups yet
		joins_existing = (best>=0) and (scores[best]>=affinity_threshold)
		if joins_existing:
			groups[best].append(unit)
		else:
			# no existing group is close enough, so this unit founds a new one
			groups.append([unit])
	return groups
		


# given a square affinity matrix listing ingroup-probabilities, cluster units into groups using the Greedy Groups Refined algorithm
# The algorithm assumes that A is symmetric and all entries are between 0 and 1; diagonal entries are excluded from all averages.
# This algorithm builds upon the Greedy Groups algorithm, by including multiple refinement rounds.
def greedy_groups_refined(A, Nrefinements=100, affinity_threshold=0.5):
	"""
	Greedy Groups clustering followed by iterative refinement.

	A                  : square matrix (2D numpy array) of pairwise ingroup-probabilities.
	Nrefinements       : maximum number of refinement rounds; iteration stops early once the unit-to-group
	                     assignment is unchanged after a round.
	affinity_threshold : minimum mean affinity required for a unit to join a group, or for a group to join a group merging.
	Returns a list of groups, each a list of unit indices; every unit belongs to exactly one group.

	Randomness comes from the module-level RNG: one permutation for the initial pass plus one per refinement round.
	"""
	Nunits = A.shape[0]
	# initial configuration: a plain Greedy Groups pass
	group_members = greedy_groups(A, affinity_threshold=affinity_threshold)
	member2group = map_members_to_groups(group_members, Nmembers=Nunits, unique=True)
	# refine groups, by re-evaluating unit-to-group affinities using the current group configuration, and merging groups that appear close to each other
	for _ in range(Nrefinements):
		group_members = _ggr_reassign_units(A, group_members, member2group, affinity_threshold)
		group_members = _ggr_merge_groups(A, group_members, affinity_threshold)
		new_member2group = map_members_to_groups(group_members, Nmembers=Nunits, unique=True)
		if(numpy.all(new_member2group==member2group)): break # algorithm converged
		member2group = new_member2group
	return group_members


# one reassignment round of greedy_groups_refined: visit all units in random order and move each to the closest group (by average affinity), or to a new group if none is close enough; empty groups are dropped
def _ggr_reassign_units(A, group_members, member2group, affinity_threshold):
	Nold = len(group_members)
	new_group_members = [[] for _ in range(Nold)] # revised versions of the previous groups; groups created during this round are appended after them
	for n in RNG.permutation(A.shape[0]):
		# affinity to each previous group; the unit's own self-affinity A[n,n] is excluded from the group it currently belongs to (NaN if it was alone there)
		affinities = [mean_without(N=len(group), Xmean=numpy.mean(A[n,group]), Xomitted=([A[n,n]] if (member2group[n]==g) else [])) for g,group in enumerate(group_members)]
		# affinity to groups created earlier in this round (these cannot contain n yet)
		affinities += [numpy.mean(A[n,group]) for group in new_group_members[Nold:]]
		nearest_group = get_largest_item(affinities, check_nan=True)
		if((nearest_group>=0) and (affinities[nearest_group]>=affinity_threshold)):
			# add this unit to the closest group
			new_group_members[nearest_group].append(n)
		else:
			# none of the existing groups are close to this unit, so create a new group
			new_group_members.append([n])
	# omit groups that don't have any members left
	return [group for group in new_group_members if len(group)>0]


# merging step of greedy_groups_refined: apply the greedy-groups idea at the group level, then fuse all groups within each resulting "group merging" into a single group
def _ggr_merge_groups(A, group_members, affinity_threshold):
	group_mergings = [] # each entry is a list of indices into group_members
	for g, group in enumerate(group_members):
		# affinities[gm] is the average affinity of the focal group to all members of all groups in group merging gm
		affinities = [numpy.mean(A[numpy.ix_(group,[m for g2 in group_merging for m in group_members[g2]])]) for group_merging in group_mergings]
		nearest_group_merging = get_largest_item(affinities, check_nan=False)
		if((nearest_group_merging>=0) and (affinities[nearest_group_merging]>=affinity_threshold)):
			# add this group to the closest group merging
			group_mergings[nearest_group_merging].append(g)
		else:
			# none of the existing group mergings are close to this group, so create a new group merging
			group_mergings.append([g])
	return [[m for g in group_merging for m in group_members[g]] for group_merging in group_mergings]





##############################
# example

# generate a hypothetical affinity matrix corresponding to predefined clusters:
# between-cluster affinities are drawn from [0,0.5), within-cluster affinities from [0.5,1)
clusters = [[0,1,2,3], [4,5], [6,7], [8,9], [10,11], [12,13,14], [15,16,17,18,19,20]]
N = 1 + max(map(max, clusters)) # total number of units
A = RNG.uniform(0, 0.5, size=(N,N))
for cluster in clusters:
	block_shape = (len(cluster), len(cluster))
	A[numpy.ix_(cluster,cluster)] = RNG.uniform(0.5, 1, size=block_shape)
numpy.fill_diagonal(A, 1) # ensure diagonal entries are 1
A = 0.5*(A + A.T) # make affinity matrix symmetric (averaging keeps the diagonal at 1)

# plot the affinity matrix as a grayscale heatmap
fig = pyplot.figure(figsize=(0.8*N, 0.8*N))
pyplot.imshow(A, cmap="gray")
fig.savefig("affinity_matrix.pdf", bbox_inches='tight')
pyplot.close(fig)


# run GG clustering algorithm
clusters = greedy_groups(A)
print("GG yielded %d clusters:" % len(clusters))
print("\n".join(map(str, map(sorted, clusters))))


# run GGR clustering algorithm
clusters = greedy_groups_refined(A)
print("GGR yielded %d clusters:" % len(clusters))
print("\n".join(map(str, map(sorted, clusters))))