If we take your derivation of 2D a bit further, it becomes clear:
N*(k*M*log(M)) + M*(k*N*log(N)) = k*M*N*(log(M)+log(N))
Since log(M) + log(N) = log(M*N), this becomes:
= k*M*N*log(M*N)
For more dimensions, with sizes A, B, C, ..., the same argument gives:
O( A*B*C*... * log(A*B*C*...) )
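If you want to check the arithmetic, here is a minimal sketch that adds up the per-axis cost for any number of dimensions and compares it with a single 1-D FFT of the combined length. It assumes the same cost model as above, k*n*log(n) for a 1-D FFT of length n; the function names `row_column_cost` and `flat_cost` are made up for illustration.

```python
import math

def row_column_cost(dims, k=1.0):
    """Cost of a multi-dimensional FFT done as 1-D FFTs along each axis in turn.

    Assumes a 1-D FFT of length n costs k * n * log2(n).
    """
    total = math.prod(dims)
    cost = 0.0
    for n in dims:
        # Along this axis there are total // n independent transforms of length n.
        cost += (total // n) * k * n * math.log2(n)
    return cost

def flat_cost(dims, k=1.0):
    """Cost of one 1-D FFT whose length is the product of the dimensions."""
    total = math.prod(dims)
    return k * total * math.log2(total)

dims = (8, 16, 32)
print(row_column_cost(dims), flat_cost(dims))  # the two totals agree
```

Each axis contributes (A*B*C*...) * k * log(size of that axis), so summing over the axes gives the log of the product.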
Mathematically speaking, a multi-dimensional FFT has the same structure as a 1-D FFT whose length is the product of the dimensions; only the twiddle factors differ. The outputs are therefore different, but the amount of work is not, so it naturally follows that the computational complexity is the same.
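To make the "same work, different twiddle factors" point concrete, here is a small sketch in plain Python. It uses a naive O(n^2) DFT as a stand-in for a real FFT, since only the output values matter for this check, and all names (`dft`, `dft2_row_column`, `dft2_direct`) and the sample data are invented for illustration.

```python
import cmath

def dft(seq):
    """Naive 1-D DFT; stands in for an FFT because only the values matter here."""
    n = len(seq)
    return [sum(seq[t] * cmath.exp(-2j * cmath.pi * f * t / n) for t in range(n))
            for f in range(n)]

def dft2_row_column(grid):
    """2-D DFT as 1-D DFTs over every row, then over every column."""
    rows = [dft(row) for row in grid]
    cols = [dft(list(col)) for col in zip(*rows)]  # transform each column
    return [list(row) for row in zip(*cols)]       # transpose back to row-major

def dft2_direct(grid):
    """2-D DFT computed straight from the definition."""
    m, n = len(grid), len(grid[0])
    return [[sum(grid[r][c] * cmath.exp(-2j * cmath.pi * (u * r / m + v * c / n))
                 for r in range(m) for c in range(n))
             for v in range(n)]
            for u in range(m)]

def close(a, b):
    return abs(a - b) < 1e-9

grid = [[1.0, 2.0, 0.0, -1.0],
        [0.5, 3.0, 1.0, 4.0]]  # arbitrary 2 x 4 sample data

separable = dft2_row_column(grid)
direct = dft2_direct(grid)
flat = dft([value for row in grid for value in row])  # one 1-D DFT of length 8

# Row-column passes reproduce the 2-D transform exactly.
matches_definition = all(close(separable[u][v], direct[u][v])
                         for u in range(2) for v in range(4))
# A flat length-8 transform of the same data gives different values.
matches_flat = all(close(separable[u][v], flat[u * 4 + v])
                   for u in range(2) for v in range(4))
print(matches_definition, matches_flat)
```

So you cannot swap a 2-D FFT for a flat 1-D FFT of the same data, even though the operation counts match.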