//========= Copyright Valve Corporation, All rights reserved. ============//
//
// Purpose: 
//
//  A class allowing storage of a sparse NxN matirx as an array of sparse rows
//===========================================================================//

#ifndef SPARSEMATRIX_H
#define SPARSEMATRIX_H

#include "tier1/utlvector.h"

/// CSparseMatrix is a matrix which compresses each row individually, not storing the zeros.  NOte,
/// that while you can randomly set any element in a CSparseMatrix, you really want to do it top to
/// bottom or you will have bad perf as data is moved around to insert new elements.
class CSparseMatrix
{
public:
	struct NonZeroValueDescriptor_t
	{
		int m_nColumnNumber;
		float m_flValue;
	};

	struct RowDescriptor_t
	{
		int m_nNonZeroCount;								// number of non-zero elements in the row
		int m_nDataIndex;									// index of NonZeroValueDescriptor_t for the first non-zero value
	};

	int m_nNumRows;
	int m_nNumCols;
	CUtlVector<RowDescriptor_t> m_rowDescriptors;
	CUtlVector<NonZeroValueDescriptor_t> m_entries;
	int m_nHighestRowAppendedTo;

	void AdjustAllRowIndicesAfter( int nStartRow, int nDelta );
public:
	FORCEINLINE float Element( int nRow, int nCol ) const;
	void SetElement( int nRow, int nCol, float flValue );
	void SetDimensions( int nNumRows, int nNumCols );
	void AppendElement( int nRow, int nCol, float flValue );
	void FinishedAppending( void );

	FORCEINLINE int Height( void ) const { return m_nNumRows; }
	FORCEINLINE int Width( void ) const { return m_nNumCols; }

};



FORCEINLINE float CSparseMatrix::Element( int nRow, int nCol ) const
{
	Assert( nCol < m_nNumCols );
	int nCount = m_rowDescriptors[nRow].m_nNonZeroCount;
	if ( nCount )
	{
		NonZeroValueDescriptor_t const *pValue = &(m_entries[m_rowDescriptors[nRow].m_nDataIndex]);
		do
		{
			int nIdx = pValue->m_nColumnNumber;
			if ( nIdx == nCol )
			{
				return pValue->m_flValue;
			}
			if ( nIdx > nCol )
			{
				break;
			}
			pValue++;
		} while( --nCount );
	}
	return 0;
}



// type-specific overrides of matrixmath template for special case sparse routines

namespace MatrixMath
{
	/// sparse * dense matrix x matrix multiplication
	template<class BTYPE, class OUTTYPE> 
	void MatrixMultiply( CSparseMatrix const &matA, BTYPE const &matB, OUTTYPE *pMatrixOut )
	{
		Assert( matA.Width() == matB.Height() );
		pMatrixOut->SetDimensions( matA.Height(), matB.Width() );
		for( int i = 0; i < matA.Height(); i++ )
		{
			for( int j = 0; j < matB.Width(); j++ )
			{
				// compute inner product efficiently because of sparsity
				int nCnt = matA.m_rowDescriptors[i].m_nNonZeroCount;
				int nDataIdx = matA.m_rowDescriptors[i].m_nDataIndex;
				float flDot = 0.0;
				for( int nIdx = 0; nIdx < nCnt; nIdx++ )
				{
					float flAValue = matA.m_entries[nIdx + nDataIdx].m_flValue;
					int nCol = matA.m_entries[nIdx + nDataIdx].m_nColumnNumber;
					flDot += flAValue * matB.Element( nCol, j );
				}
				pMatrixOut->SetElement( i, j, flDot );
			}
		}
	}

	FORCEINLINE void AppendElement( CSparseMatrix &matrix, int nRow, int nCol, float flValue )
	{
		matrix.AppendElement( nRow, nCol, flValue );			// default implementation
	}

	FORCEINLINE void FinishedAppending( CSparseMatrix &matrix )
	{
		matrix.FinishedAppending();
	}

};





#endif // SPARSEMATRIX_H