#include "driver.h"
#include <stdio.h>
#include <sys/time.h>
#include <unistd.h>
#include <stdint.h>
#include <arm_bf16.h>
#include <arm_neon.h>
#include <stdlib.h>
#include <inttypes.h>
#include <pthread.h>
#include <Accelerate/Accelerate.h>

/*
 * Routines implemented outside of this translation unit.
 */

// ISA probes: micro_bench() calls each one right after announcing the
// corresponding check on stdout. sve_support() is declared but only
// referenced from commented-out code in this file.
extern void sve_support();
extern void sve_streaming_support();
extern void sme_support();
extern void sme2_support();
// micro_bench() passes a 32-float input holding 1..32 and a zeroed 32-float
// output; the number of output entries > 0 afterwards is multiplied by 32 and
// reported as the streaming vector length in bits.
extern void sve_streaming_vlength( float * i_a,
                                   float * i_b );
// Peak-throughput kernels taking a repetition count. micro_bench() multiplies
// the returned int by that count and divides by the elapsed time to report
// GFLOPS, i.e. it treats the return value as operations per repetition
// (presumably set by the kernel implementation -- confirm there).
extern int peak_neon_fmla( long reps );
extern int peak_sve_fmla_streaming( long reps );
extern int peak_sme_fmopa_1( long reps );
extern int peak_sme_fmopa_2( long reps );
extern int peak_sme_fmopa_4( long reps );
extern int peak_sme_fmopa_4_reorder( long reps );
extern int peak_sme_fmopa_widening( long reps );
extern int peak_sme_bfmopa_widening( long reps );
// Showcase kernels: micro_bench() fills i_a / i_b, zero-initializes i_c and
// prints i_c after the call.
extern void example_sme_fmopa( float * i_a,
                               float * i_b,
                               float * i_c );
extern void example_sme_bfmopa_widening( bfloat16_t * i_a,
                                         bfloat16_t * i_b,
                                         float      * i_c );

/**
 * Seconds elapsed between two gettimeofday() samples.
 *
 * @param i_start sample taken before the measured region.
 * @param i_end sample taken after the measured region.
 * @return i_end - i_start in seconds.
 */
static double bench_cblas_elapsed( const struct timeval * i_start,
                                   const struct timeval * i_end ) {
  long l_seconds  = i_end->tv_sec  - i_start->tv_sec;
  long l_useconds = i_end->tv_usec - i_start->tv_usec;

  return l_seconds + l_useconds/1000000.0;
}

/**
 * Allocates a 128-byte aligned array of floats and sets every entry to 1.0f.
 *
 * @param i_num_values number of floats to allocate.
 * @return pointer to the array (release with free()), NULL if the allocation failed.
 */
static float * bench_cblas_alloc_ones( int64_t i_num_values ) {
  void * l_raw = NULL;

  if( posix_memalign( &l_raw, 128, (size_t) i_num_values * sizeof(float) ) != 0 ) {
    return NULL;
  }

  float * l_values = l_raw;
  for( int64_t l_en = 0; l_en < i_num_values; l_en++ ) {
    l_values[l_en] = 1.0f;
  }

  return l_values;
}

/**
 * Executes the column-major update C += op(A) * op(B) through cblas_sgemm
 * (alpha = beta = 1) the given number of times and measures the wall-clock time.
 *
 * @param i_m number of rows of op(A) and C.
 * @param i_n number of columns of op(B) and C.
 * @param i_k number of columns of op(A) and rows of op(B).
 * @param i_lda leading dimension of A.
 * @param i_ldb leading dimension of B.
 * @param i_ldc leading dimension of C.
 * @param i_trans_a 0: A is used as stored, otherwise A is transposed.
 * @param i_trans_b 0: B is used as stored, otherwise B is transposed.
 * @param i_a matrix A.
 * @param i_b matrix B.
 * @param io_c matrix C; accumulates the result of every repetition.
 * @param i_num_reps number of back-to-back SGEMM calls.
 * @return elapsed time of all repetitions in seconds.
 */
static double bench_cblas_run( int64_t       i_m,
                               int64_t       i_n,
                               int64_t       i_k,
                               int64_t       i_lda,
                               int64_t       i_ldb,
                               int64_t       i_ldc,
                               int           i_trans_a,
                               int           i_trans_b,
                               const float * i_a,
                               const float * i_b,
                               float       * io_c,
                               int64_t       i_num_reps ) {
  struct timeval l_start;
  struct timeval l_end;

  gettimeofday( &l_start, NULL );
  for( int64_t l_re = 0; l_re < i_num_reps; l_re++ ) {
    cblas_sgemm( CblasColMajor,
                 i_trans_a == 0 ? CblasNoTrans : CblasTrans,
                 i_trans_b == 0 ? CblasNoTrans : CblasTrans,
                 i_m,
                 i_n,
                 i_k,
                 1,
                 i_a,
                 i_lda,
                 i_b,
                 i_ldb,
                 1,
                 io_c,
                 i_ldc );
  }
  gettimeofday( &l_end, NULL );

  return bench_cblas_elapsed( &l_start, &l_end );
}

/**
 * Benchmarks cblas_sgemm for a single column-major problem and prints the
 * sustained FP32 GFLOPS.
 *
 * After one warmup call, i_num_reps_initial calls are timed to estimate how
 * many repetitions are needed to run for about i_target_time seconds; that
 * many repetitions are then timed for the reported result. All matrices are
 * initialized to 1.0f and C accumulates over all calls (beta = 1).
 *
 * @param i_m number of rows of op(A) and C.
 * @param i_n number of columns of op(B) and C.
 * @param i_k number of columns of op(A) and rows of op(B).
 * @param i_lda leading dimension of A.
 * @param i_ldb leading dimension of B.
 * @param i_ldc leading dimension of C.
 * @param i_trans_a 0: A is not transposed, otherwise A is transposed.
 * @param i_trans_b 0: B is not transposed, otherwise B is transposed.
 * @param i_num_reps_initial number of repetitions of the calibration run.
 * @param i_target_time targeted duration of the measured run in seconds.
 */
void bench_cblas( int64_t i_m,
                  int64_t i_n,
                  int64_t i_k,
                  int64_t i_lda,
                  int64_t i_ldb,
                  int64_t i_ldc,
                  int     i_trans_a,
                  int     i_trans_b,
                  int64_t i_num_reps_initial,
                  double  i_target_time ) {
  printf( "Running CBLAS SGEMM...\n" );
  printf( "  M/N/K:       %" PRId64 "/%" PRId64 "/%" PRId64 "\n", i_m, i_n, i_k );
  printf( "  ldA/ldB/ldC: %" PRId64 "/%" PRId64 "/%" PRId64 "\n", i_lda, i_ldb, i_ldc );
  printf( "  TransA/TransB: %d/%d\n", i_trans_a, i_trans_b );

  if(    i_m   <= 0 || i_n   <= 0 || i_k   <= 0
      || i_lda <= 0 || i_ldb <= 0 || i_ldc <= 0 ) {
    fprintf( stderr, "  Error: matrix sizes and leading dimensions have to be positive.\n" );
    return;
  }

  // In column-major storage a transposed operand is stored with swapped
  // dimensions: A is k-by-m instead of m-by-k, B is n-by-k instead of k-by-n.
  // The number of stored columns therefore depends on the transpose flags.
  int64_t l_cols_a = ( i_trans_a == 0 ) ? i_k : i_m;
  int64_t l_cols_b = ( i_trans_b == 0 ) ? i_n : i_k;

  // allocate and init the matrices
  float * l_a = bench_cblas_alloc_ones( i_lda * l_cols_a );
  float * l_b = bench_cblas_alloc_ones( i_ldb * l_cols_b );
  float * l_c = bench_cblas_alloc_ones( i_ldc * i_n );

  if( l_a == NULL || l_b == NULL || l_c == NULL ) {
    fprintf( stderr, "  Error: failed to allocate the matrices.\n" );
    free( l_a );
    free( l_b );
    free( l_c );
    return;
  }

  // warmup (time is discarded)
  bench_cblas_run( i_m, i_n, i_k,
                   i_lda, i_ldb, i_ldc,
                   i_trans_a, i_trans_b,
                   l_a, l_b, l_c,
                   1 );

  // calibration run
  double l_total_time = bench_cblas_run( i_m, i_n, i_k,
                                         i_lda, i_ldb, i_ldc,
                                         i_trans_a, i_trans_b,
                                         l_a, l_b, l_c,
                                         i_num_reps_initial );

  // derive the repetitions for the measured run; fall back to the initial
  // count if the calibration run was too short for the timer to resolve
  int64_t l_num_reps = i_num_reps_initial > 1 ? i_num_reps_initial : 1;
  if( l_total_time > 0 ) {
    double l_estimate = ( i_target_time * i_num_reps_initial ) / l_total_time;
    if( l_estimate < (double) INT64_MAX ) {
      l_num_reps = (int64_t) l_estimate;
    }
    else {
      // clamp: converting an out-of-range double to int64_t is undefined
      l_num_reps = INT64_MAX;
    }
    l_num_reps = l_num_reps > 1 ? l_num_reps : 1;
  }

  // computed in double precision to rule out an integer overflow
  double l_num_flops = 2.0 * i_m * i_n * i_k * l_num_reps;

  // measured run
  l_total_time = bench_cblas_run( i_m, i_n, i_k,
                                  i_lda, i_ldb, i_ldc,
                                  i_trans_a, i_trans_b,
                                  l_a, l_b, l_c,
                                  l_num_reps );

  double l_gflops = l_num_flops / l_total_time;
  l_gflops *= 1.0E-9;

  printf( "  Repetitions: %" PRId64 "\n", l_num_reps );
  printf( "  Total time:  %f\n", l_total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );

  free( l_a );
  free( l_b );
  free( l_c );
}

void micro_bench() {
  // set affinity
  //pthread_set_qos_class_self_np( QOS_CLASS_USER_INTERACTIVE, 0 );
  pthread_set_qos_class_self_np( QOS_CLASS_BACKGROUND, 0 );

  // vars
  long l_num_reps = 0;
  double l_gflops = 0;
  struct timeval l_start;
  struct timeval l_end;
  long l_seconds = 0;
  long l_useconds = 0;
  double total_time = 0;

  /*
   * Check for ISA extensions
   */
  // printf( "Checking for SVE support...\n" );
  // sve_support();
  printf( "Checking for SVE streaming support...\n" );
  sve_streaming_support();
  printf( "Checking for SME support...\n" );
  sme_support();
  printf( "Checking for SME2 support...\n" );
  sme2_support();

  /*
   * Determine SVE vector length in streaming mode
   */
  printf( "Determining vector length of SVE in streaming mode \n" );
  float l_a[32];
  float l_b[32] = {0};
  for( int64_t l_i = 0; l_i < 32; l_i++ ){
    l_a[l_i] = (float) l_i + 1;
  }
  sve_streaming_vlength( l_a, l_b );

  int64_t l_num_bits = 0;
  for( int64_t l_i = 0; l_i < 32; l_i++ ){
    if( l_b[l_i] > 0 ){
      l_num_bits += 32;
    }
  }
  printf( "  %lld bits\n", l_num_bits );

  /*
   * Determine peak when using Neon
   */
  printf("Determining Neon peak performance...\n");
  l_num_reps = 1000000000;

  gettimeofday( &l_start, NULL );
  l_gflops = peak_neon_fmla(l_num_reps);
  gettimeofday( &l_end, NULL );
  
  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;
  
  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );

  /*
   * Determine peak when using SVE in streaming mode
   */
  printf( "Determining peak performance for SVE in streaming mode...\n" );
  l_num_reps = 100000000;
  gettimeofday( &l_start, NULL );
  l_gflops = peak_sve_fmla_streaming( l_num_reps );
  gettimeofday( &l_end, NULL );

  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;

  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );

  /*
   * Determine peak when using SME FMOPA accumulating in a single tile
   */
  printf( "Determining peak performance for SME FMOPA accumulating in a single tile...\n" );
  l_num_reps = 250000000;
  gettimeofday( &l_start, NULL );
  l_gflops = peak_sme_fmopa_1( l_num_reps );
  gettimeofday( &l_end, NULL );

  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;

  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );

  /*
   * Determine peak when using SME FMOPA accumulating in two tile
   */
  printf( "Determining peak performance for SME FMOPA accumulating in two tiles...\n" );
  l_num_reps = 250000000;
  gettimeofday( &l_start, NULL );
  l_gflops = peak_sme_fmopa_2( l_num_reps );
  gettimeofday( &l_end, NULL );

  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;

  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );

  /*
   * Determine peak when using SME FMOPA accumulating in four tiles
   */
  printf( "Determining peak performance for SME FMOPA accumulating in four tiles...\n" );
  l_num_reps = 250000000;
  gettimeofday( &l_start, NULL );
  l_gflops = peak_sme_fmopa_4( l_num_reps );
  gettimeofday( &l_end, NULL );

  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;

  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );


  /*
   * Determine peak when using SME FMOPA accumulating in four tiles which requires reordering
   */
  printf( "Determining peak performance for SME FMOPA accumulating in four tiles (requires reordering)...\n" );
  l_num_reps = 250000000;
  gettimeofday( &l_start, NULL );
  l_gflops = peak_sme_fmopa_4_reorder( l_num_reps );
  gettimeofday( &l_end, NULL );

  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;

  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );

  /*
   * Determine peak when using SME FMOPA (widening)
   */
  printf( "Determining peak performance for SME FMOPA(widening)...\n" );
  l_num_reps = 250000000;
  gettimeofday( &l_start, NULL );
  l_gflops = peak_sme_fmopa_widening( l_num_reps );
  gettimeofday( &l_end, NULL );

  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;

  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );

  /*
   * Determine peak when using SME BFMOPA(widening)
   */
  printf( "Determining peak performance for SME BFMOPA(widening)...\n" );
  l_num_reps = 250000000;
  gettimeofday( &l_start, NULL );
  l_gflops = peak_sme_bfmopa_widening( l_num_reps );
  gettimeofday( &l_end, NULL );

  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;

  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );


  /*
   * Run CBLAS benchmarks
   */
  printf( "Running CBLAS benchmarks...\n" );
  int64_t l_size = 16;
  int64_t l_num_reps_initial = 65536;
  double  l_target_time = 1.0;
  for( int64_t l_si = 0; l_si < 10; l_si++ ) {
    bench_cblas( l_size,
                 l_size,
                 l_size,
                 l_size,
                 l_size,
                 l_size,
                 0,
                 0,
                 l_num_reps_initial,
                 l_target_time );

    bench_cblas( l_size,
                 l_size,
                 l_size,
                 l_size,
                 l_size,
                 l_size,
                 0,
                 1,
                 l_num_reps_initial,
                 l_target_time );

    l_size *= 2;
    l_num_reps_initial /= 8;
    l_num_reps_initial = l_num_reps_initial > 1 ? l_num_reps_initial : 1;
  }

  /*
   * Showcase outer-product SME FMOPA
   */
  printf( "Running example SME FMOPA...\n" );
  float l_a_sme_fmopa[32];
  float l_b_sme_fmopa[32];
  float l_c_sme_fmopa[32*32] = {0};

  for( int64_t l_i = 0; l_i < 32; l_i++ ){
    l_a_sme_fmopa[l_i] = l_i + 1;
    l_b_sme_fmopa[l_i] = l_i + 1;
  }

  example_sme_fmopa( l_a_sme_fmopa,
                     l_b_sme_fmopa,
                     l_c_sme_fmopa );

  for( int64_t l_i = 0; l_i < 32; l_i++ ){
    for( int64_t l_j = 0; l_j < 16; l_j++ ){
      printf( "  %f", l_c_sme_fmopa[l_i*16+l_j] );
    }
    printf( "\n" );
  }

  /*
   * Showcase BFMOPA(widening)
   */
  printf( "Running example SME BFMOPA(widening)...\n" );
  bfloat16_t l_a_bf16[16*2];
  bfloat16_t l_b_bf16[16*2];
  float      l_c_fp32[16*16] = {0};

  for( int64_t l_i = 0; l_i < 16*2; l_i++ ){
    float l_val = (float) l_i + 1;
    l_a_bf16[l_i] = vcvth_bf16_f32( l_val );
    l_b_bf16[l_i] = vcvth_bf16_f32( l_val );
  }

  example_sme_bfmopa_widening( l_a_bf16,
                               l_b_bf16,
                               l_c_fp32 );

  for( int64_t l_i = 0; l_i < 16; l_i++ ){
    for( int64_t l_j = 0; l_j < 16; l_j++ ){
      printf( "  %f", l_c_fp32[l_i*16+l_j] );
    }
    printf( "\n" );
  }
}
