#include "driver.h"
#include <stdio.h>
#include <sys/time.h>
#include <unistd.h>
#include <stdint.h>

/*
 * Externally defined routines used by micro_bench(); no definitions are
 * visible in this file. The notes below describe only how this file uses them.
 */

// ISA-extension probes. micro_bench() calls the streaming-SVE and SME probes
// before any benchmark; the call to sve_support() is currently commented out.
// NOTE(review): behavior on hardware lacking the extension is not visible here.
extern void sve_support();
extern void sve_streaming_support();
extern void sme_support();

// Called with a 32-float input filled with 1..32 and a zeroed 32-float output.
// micro_bench() counts the positive output entries afterwards to derive the
// streaming vector length.
// NOTE(review): presumably copies one streaming vector from i_a to i_b - confirm.
extern void sve_streaming_vlength( float * i_a,
                                   float * i_b );

// Peak-performance kernels taking a repetition count. micro_bench() scales the
// return value by the repetition count, so it is treated as the number of
// floating-point operations per repetition.
extern int peak_neon_fmla( long reps );
extern int peak_sve_fmla_streaming( long reps );
extern int peak_sme_fmopa( long reps );

// Outer-product example. micro_bench() passes two 32-float inputs and a zeroed
// 32*32-float output buffer.
// NOTE(review): the layout of the data written to i_c is not visible here.
extern void example_sme_fmopa( float * i_a,
                               float * i_b,
                               float * i_c );

void micro_bench() {
  // vars
  long l_num_reps = 0;
  double l_gflops = 0;
  struct timeval l_start;
  struct timeval l_end;
  long l_seconds = 0;
  long l_useconds = 0;
  double total_time = 0;

  /*
   * Check for ISA extensions
   */
  // printf( "Checking for SVE support...\n" );
  // sve_support();
  printf( "Checking for SVE streaming support...\n" );
  sve_streaming_support();
  printf( "Checking for SME support...\n" );
  sme_support();

  /*
   * Determine SVE vector length in streaming mode
   */
  printf( "Determining vector length of SVE in streaming mode \n" );
  float l_a[32];
  float l_b[32] = {0};
  for( int64_t l_i = 0; l_i < 32; l_i++ ){
    l_a[l_i] = (float) l_i + 1;
  }
  sve_streaming_vlength( l_a, l_b );

  int64_t l_num_bits = 0;
  for( int64_t l_i = 0; l_i < 32; l_i++ ){
    if( l_b[l_i] > 0 ){
      l_num_bits += 32;
    }
  }
  printf( "  %lld bits\n", l_num_bits );

  /*
   * Determine peak when using Neon
   */
  printf("Determining Neon peak performance...\n");
  l_num_reps = 1000000000;

  gettimeofday(&l_start, NULL);
  l_gflops = peak_neon_fmla(l_num_reps);
  gettimeofday(&l_end, NULL);
  
  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;
  
  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );

  /*
   * Determine peak when using SVE in streaming mode
   */
  printf( "Determining peak performance for SVE in streaming mode...\n" );
  l_num_reps = 100000000;
  gettimeofday(&l_start, NULL);
  l_gflops = peak_sve_fmla_streaming(l_num_reps);
  gettimeofday(&l_end, NULL);

  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;

  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );

  /*
   * Determine peak when using SME
   */
  printf( "Determining peak performance for SME...\n" );
  l_num_reps = 250000000;
  gettimeofday(&l_start, NULL);
  l_gflops = peak_sme_fmopa(l_num_reps);
  gettimeofday(&l_end, NULL);

  l_seconds  = l_end.tv_sec  - l_start.tv_sec;
  l_useconds = l_end.tv_usec - l_start.tv_usec;
  total_time = l_seconds + l_useconds/1000000.0;

  l_gflops *= l_num_reps;
  l_gflops *= 1.0E-9;
  l_gflops /= total_time;

  printf( "  Repetitions: %ld\n", l_num_reps );
  printf( "  Total time:  %f\n", total_time );
  printf( "  FP32 GFLOPS: %f\n", l_gflops );

  /*
   * Showcase outer-product SME FMOPA
   */
  printf( "Running example SME FMOPA...\n" );
  float l_a_sme_fmopa[32];
  float l_b_sme_fmopa[32];
  float l_c_sme_fmopa[32*32] = {0};

  for( int64_t l_i = 0; l_i < 32; l_i++ ){
    l_a_sme_fmopa[l_i] = l_i + 1;
    l_b_sme_fmopa[l_i] = l_i + 1;
  }

  example_sme_fmopa( l_a_sme_fmopa,
                     l_b_sme_fmopa,
                     l_c_sme_fmopa );

  for( int64_t l_i = 0; l_i < 32; l_i++ ){
    for( int64_t l_j = 0; l_j < 16; l_j++ ){
      printf( "  %f", l_c_sme_fmopa[l_i*16+l_j] );
    }
    printf( "\n" );
  }
}
