java实现任意矩阵Strassen算法_java

本例输入为两个任意尺寸的矩阵m * n, n * m,输出为两个矩阵的乘积。计算任意尺寸矩阵相乘时,使用了Strassen算法。程序为自编,经过测试,请放心使用。基本算法是:
1.对于方阵(正方形矩阵),找到最大的l, 使得l = 2 ^ k, k为整数并且l < m。边长为l的方形矩阵则采用Strassen算法,其余部分以及方形矩阵中遗漏的部分用蛮力法。
2.对于非方阵,依照行列相应添加0使其成为方阵。
StrassenMethodTest.java

package matrixalgorithm;

import java.util.Scanner;

public class StrassenMethodTest {

  private StrassenMethod strassenMultiply;

   StrassenMethodTest(){
    strassenMultiply = new StrassenMethod();
  }//end cons

   public static void main(String[] args){
    Scanner input = new Scanner(System.in);
    System.out.println("Input row size of the first matrix: ");
    int arow = input.nextInt();
    System.out.println("Input column size of the first matrix: ");
    int acol = input.nextInt();
    System.out.println("Input row size of the second matrix: ");
    int brow = input.nextInt();
    System.out.println("Input column size of the second matrix: ");
    int bcol = input.nextInt();

    double[][] A = new double[arow][acol];
    double[][] B = new double[brow][bcol];
    double[][] C = new double[arow][bcol];
    System.out.println("Input data for matrix A: ");

    /*In all of the codes later in this project,
    r means row while c means column.
    */
    for (int r = 0; r < arow; r++) {
      for (int c = 0; c < acol; c++) {
        System.out.printf("Data of A[%d][%d]: ", r, c);
        A[r][c] = input.nextDouble();
      }//end inner loop
    }//end loop

    System.out.println("Input data for matrix B: ");
    for (int r = 0; r < brow; r++) {
      for (int c = 0; c < bcol; c++) {
        System.out.printf("Data of A[%d][%d]: ", r, c);
        B[r][c] = input.nextDouble();
      }//end inner loop
    }//end loop

    StrassenMethodTest algorithm = new StrassenMethodTest();
    C = algorithm.multiplyRectMatrix(A, B, arow, acol, brow, bcol);

    //Display the calculation result:
    System.out.println("Result from matrix C: ");
    for (int r = 0; r < arow; r++) {
      for (int c = 0; c < bcol; c++) {
        System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]);
      }//end inner loop
    }//end outter loop
   }//end main

  //Deal with matrices that are not square:
  public double[][] multiplyRectMatrix(double[][] A, double[][] B,
      int arow, int acol, int brow, int bcol) {
    if (arow != bcol) //Invalid multiplicatio
      return new double[][]{{0}};

    double[][] C = new double[arow][bcol];

    if (arow < acol) {

      double[][] newA = new double[acol][acol];
      double[][] newB = new double[brow][brow];

      int n = acol;

      for (int r = 0; r < acol; r++)
        for (int c = 0; c < acol; c++)
          newA[r][c] = 0.0;

      for (int r = 0; r < brow; r++)
        for (int c = 0; c < brow; c++)
          newB[r][c] = 0.0;

      for (int r = 0; r < arow; r++)
        for (int c = 0; c < acol; c++)
          newA[r][c] = A[r][c];

      for (int r = 0; r < brow; r++)
        for (int c = 0; c < bcol; c++)
          newB[r][c] = B[r][c];

      double[][] C2 = multiplySquareMatrix(newA, newB, n);
      for(int r = 0; r < arow; r++)
        for(int c = 0; c < bcol; c++)
          C[r][c] = C2[r][c];
    }//end if

    else if(arow == acol)
      C = multiplySquareMatrix(A, B, arow);

    else {
      int n = arow;
      double[][] newA = new double[arow][arow];
      double[][] newB = new double[bcol][bcol];

      for (int r = 0; r < arow; r++)
        for (int c = 0; c < arow; c++)
          newA[r][c] = 0.0;

      for (int r = 0; r < bcol; r++)
        for (int c = 0; c < bcol; c++)
          newB[r][c] = 0.0;

      for (int r = 0; r < arow; r++)
        for (int c = 0; c < acol; c++)
          newA[r][c] = A[r][c];

      for (int r = 0; r < brow; r++)
        for (int c = 0; c < bcol; c++)
          newB[r][c] = B[r][c];

      double[][] C2 = multiplySquareMatrix(newA, newB, n);
      for(int r = 0; r < arow; r++)
        for(int c = 0; c < bcol; c++)
          C[r][c] = C2[r][c];
    }//end else

     return C;
   }//end method

  //Deal with matrices that are square matrices.
   public double[][] multiplySquareMatrix(double[][] A2, double[][] B2, int n){

     double[][] C2 = new double[n][n];

     for(int r = 0; r < n; r++)
       for(int c = 0; c < n; c++)
         C2[r][c] = 0;

     if(n == 1){
      C2[0][0] = A2[0][0] * B2[0][0];
      return C2;
     }//end if

     int exp2k = 2;

     while(exp2k <= (n / 2) ){
       exp2k *= 2;
     }//end loop

     if(exp2k == n){
       C2 = strassenMultiply.strassenMultiplyMatrix(A2, B2, n);
       return C2;
     }//end else

     //The "biggest" strassen matrix:
     double[][][] A = new double[6][exp2k][exp2k];
     double[][][] B = new double[6][exp2k][exp2k];
     double[][][] C = new double[6][exp2k][exp2k];

     for(int r = 0; r < exp2k; r++){
       for(int c = 0; c < exp2k; c++){
         A[0][r][c] = A2[r][c];
         B[0][r][c] = B2[r][c];
       }//end inner loop
     }//end outter loop

    C[0] = strassenMultiply.strassenMultiplyMatrix(A[0], B[0], exp2k);

    for(int r = 0; r < exp2k; r++)
      for(int c = 0; c < exp2k; c++)
        C2[r][c] = C[0][r][c];

    int middle = exp2k / 2;

    for(int r = 0; r < middle; r++){
      for(int c = exp2k; c < n; c++){
        A[1][r][c - exp2k] = A2[r][c];
        B[3][r][c - exp2k] = B2[r][c];
      }//end inner loop
    }//end outter loop

    for(int r = exp2k; r < n; r++){
      for(int c = 0; c < middle; c++){
        A[3][r - exp2k][c] = A2[r][c];
        B[1][r - exp2k][c] = B2[r][c];
      }//end inner loop
    }//end outter loop

    for(int r = middle; r < exp2k; r++){
      for(int c = exp2k; c < n; c++){
        A[2][r - middle][c - exp2k] = A2[r][c];
        B[4][r - middle][c - exp2k] = B2[r][c];
      }//end inner loop
    }//end outter loop

    for(int r = exp2k; r < n; r++){
      for(int c = middle; c < n - exp2k + 1; c++){
        A[4][r - exp2k][c - middle] = A2[r][c];
        B[2][r - exp2k][c - middle] = B2[r][c];
      }//end inner loop
    }//end outter loop

    for(int i = 1; i <= 4; i++)
      C[i] = multiplyRectMatrix(A[i], B[i], middle, A[i].length, A[i].length, middle);

    /*
    Calculate the final results of grids in the "biggest 2^k square,
    according to the rules of matrice multiplication.
    */
    for (int row = 0; row < exp2k; row++) {
       for (int col = 0; col < exp2k; col++) {
         for (int k = exp2k; k < n; k++) {
           C2[row][col] += A2[row][k] * B2[k][col];
         }//end loop
       }//end inner loop
     }//end outter loop

    //Use brute force to solve the rest, will be improved later:
    for(int col = exp2k; col < n; col++){
      for(int row = 0; row < n; row++){
        for(int k = 0; k < n; k++)
          C2[row][col] += A2[row][k] * B2[k][row];
      }//end inner loop
    }//end outter loop

    for(int row = exp2k; row < n; row++){
      for(int col = 0; col < exp2k; col++){
        for(int k = 0; k < n; k++)
          C2[row][col] += A2[row][k] * B2[k][row];
      }//end inner loop
    }//end outter loop   

    return C2;
   }//end method

}//end class

StrassenMethod.java

package matrixalgorithm;

import java.util.Scanner;

public class StrassenMethod {

  private double[][][][] A = new double[2][2][][];
  private double[][][][] B = new double[2][2][][];
  private double[][][][] C = new double[2][2][][];

  /*//Codes for testing this class:
    public static void main(String[] args) {
    Scanner input = new Scanner(System.in);
    System.out.println("Input size of the matrix: ");
    int n = input.nextInt();

    double[][] A = new double[n][n];
    double[][] B = new double[n][n];
    double[][] C = new double[n][n];
    System.out.println("Input data for matrix A: ");
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        System.out.printf("Data of A[%d][%d]: ", r, c);
        A[r][c] = input.nextDouble();
      }//end inner loop
    }//end loop

    System.out.println("Input data for matrix B: ");
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        System.out.printf("Data of A[%d][%d]: ", r, c);
        B[r][c] = input.nextDouble();
      }//end inner loop
    }//end loop

    StrassenMethod algorithm = new StrassenMethod();
    C = algorithm.strassenMultiplyMatrix(A, B, n);

    System.out.println("Result from matrix C: ");
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]);
      }//end inner loop
    }//end outter loop

  }//end main*/

   public double[][] strassenMultiplyMatrix(double[][] A2, double B2[][], int n){
    double[][] C2 = new double[n][n];
    //Initialize the matrix:
    for(int rowIndex = 0; rowIndex < n; rowIndex++)
      for(int colIndex = 0; colIndex < n; colIndex++)
        C2[rowIndex][colIndex] = 0.0;

    if(n == 1)
      C2[0][0] = A2[0][0] * B2[0][0];
    //"Slice matrices into 2 * 2 parts:
    else{
      double[][][][] A = new double[2][2][n / 2][n / 2];
      double[][][][] B = new double[2][2][n / 2][n / 2];
      double[][][][] C = new double[2][2][n / 2][n / 2];

      for(int r = 0; r < n / 2; r++){
        for(int c = 0; c < n / 2; c++){
          A[0][0][r][c] = A2[r][c];
          A[0][1][r][c] = A2[r][n / 2 + c];
          A[1][0][r][c] = A2[n / 2 + r][c];
          A[1][1][r][c] = A2[n / 2 + r][n / 2 + c];

          B[0][0][r][c] = B2[r][c];
          B[0][1][r][c] = B2[r][n / 2 + c];
          B[1][0][r][c] = B2[n / 2 + r][c];
          B[1][1][r][c] = B2[n / 2 + r][n / 2 + c];
        }//end loop
      }//end loop

      n = n / 2;

      double[][][] S = new double[10][n][n];
      S[0] = minusMatrix(B[0][1], B[1][1], n);
      S[1] = addMatrix(A[0][0], A[0][1], n);
      S[2] = addMatrix(A[1][0], A[1][1], n);
      S[3] = minusMatrix(B[1][0], B[0][0], n);
      S[4] = addMatrix(A[0][0], A[1][1], n);
      S[5] = addMatrix(B[0][0], B[1][1], n);
      S[6] = minusMatrix(A[0][1], A[1][1], n);
      S[7] = addMatrix(B[1][0], B[1][1], n);
      S[8] = minusMatrix(A[0][0], A[1][0], n);
      S[9] = addMatrix(B[0][0], B[0][1], n);

      double[][][] P = new double[7][n][n];
      P[0] = strassenMultiplyMatrix(A[0][0], S[0], n);
      P[1] = strassenMultiplyMatrix(S[1], B[1][1], n);
      P[2] = strassenMultiplyMatrix(S[2], B[0][0], n);
      P[3] = strassenMultiplyMatrix(A[1][1], S[3], n);
      P[4] = strassenMultiplyMatrix(S[4], S[5], n);
      P[5] = strassenMultiplyMatrix(S[6], S[7], n);
      P[6] = strassenMultiplyMatrix(S[8], S[9], n);

      C[0][0] = addMatrix(minusMatrix(addMatrix(P[4], P[3], n), P[1], n), P[5], n);
      C[0][1] = addMatrix(P[0], P[1], n);
      C[1][0] = addMatrix(P[2], P[3], n);
      C[1][1] = minusMatrix(minusMatrix(addMatrix(P[4], P[0], n), P[2], n), P[6], n);

      n *= 2;

       for(int r = 0; r < n / 2; r++){
        for(int c = 0; c < n / 2; c++){
          C2[r][c] = C[0][0][r][c];
          C2[r][n / 2 + c] = C[0][1][r][c];
          C2[n / 2 + r][c] = C[1][0][r][c];
          C2[n / 2 + r][n / 2 + c] = C[1][1][r][c];
        }//end inner loop
      }//end outter loop
    }//end else     

    return C2;
  }//end method

   //Add two matrices according to matrix addition.
   private double[][] addMatrix(double[][] A, double[][] B, int n){
    double C[][] = new double[n][n];

    for(int r = 0; r < n; r++)
      for(int c = 0; c < n; c++)
        C[r][c] = A[r][c] + B[r][c];

    return C;
  }//end method 

   //Substract two matrices according to matrix addition.
   private double[][] minusMatrix(double[][] A, double[][] B, int n){
    double C[][] = new double[n][n];

    for(int r = 0; r < n; r++)
      for(int c = 0; c < n; c++)
        C[r][c] = A[r][c] - B[r][c];

    return C;
  }//end method

}//end class

希望本文所述对大家学习java程序设计有所帮助。

以上是小编为您精心准备的的内容,在的博客、问答、公众号、人物、课程等栏目也有的相关内容,欢迎继续使用右上角搜索按钮进行搜索java矩阵算法
Strassen算法
strassen矩阵乘法、strassen算法、strassen矩阵乘法代码、strassen矩阵乘法优化、strassen矩阵规模太大,以便于您获取更多的相关知识。

时间: 2024-08-22 14:33:38

java实现任意矩阵Strassen算法_java的相关文章

java实现Base64加密解密算法_java

Base64是网络上最常见的用于传输8Bit字节代码的编码方式之一,大家可以查看RFC2045-RFC2049,上面有MIME的详细规范.Base64编码可用于在HTTP环境下传递较长的标识信息.例如,在Java Persistence系统Hibernate中,就采用了Base64来将一个较长的唯一标识符(一般为128-bit的UUID)编码为一个字符串,用作HTTP表单和HTTP GET URL中的参数.在其他应用程序中,也常常需要把二进制数据编码为适合放在URL(包括隐藏表单域)中的形式.

图解程序员必须掌握的Java常用8大排序算法_java

这篇文章主要介绍了Java如何实现八个常用的排序算法:插入排序.冒泡排序.选择排序.希尔排序 .快速排序.归并排序.堆排序和LST基数排序,分享给大家一起学习. 分类1)插入排序(直接插入排序.希尔排序) 2)交换排序(冒泡排序.快速排序) 3)选择排序(直接选择排序.堆排序) 4)归并排序 5)分配排序(基数排序) 所需辅助空间最多:归并排序 所需辅助空间最少:堆排序 平均速度最快:快速排序 不稳定:快速排序,希尔排序,堆排序. 先来看看8种排序之间的关系: 1.直接插入排序 (1)基本思想:

两种JAVA实现短网址服务算法_java

短网址(Short URL) ,顾名思义就是看起来很短的网址.自从twitter推出短网址服务以后,各大互联网公司都推出了自己的短网址服务.短网址最大的优点就是短,字符少,便于发布.传播.复制和存储. 通过网上的搜索,感觉流传了2种短网址算法,一种是基于MD5码的,一种是基于自增序列的. 1.基于MD5码 : 这种算法计算的短网址长度一般是5位或者6位,计算过程中可能出现碰撞(概率很小),可表达的url数量为62 的5次方或6次方.感觉google(http://goo.gl),微博用的是类似这

Java抢红包的红包生成算法_java

马上过年了.过年微信红包很火,最近有个项目也要做抢红包,于是写了个红包的生成算法. 红包生成算法的需求 预先生成所有的红包还是一个请求随机生成一个红包 简单来说,就是把一个大整数m分解(直接以"分为单位,如1元即100)分解成n个小整数的过程,小整数的范围是[min, max]. 最简单的思路,先保底,每个小红包保证有min,然后每个请求都随机生成一个0到(max-min)范围的整数,再加上min就是红包的钱数. 这个算法虽然简单,但是有一个弊端:最后生成的红包可能都是min钱数的.也就是说可能

[算法系列之十五]Strassen矩阵相乘算法

引言 Strassen矩阵乘法是一种典型的分治算法.目前为止,我们已经见过一些分治策略的算法了,例如归并排序和Karatsuba大数快速乘法.现在,让我们看看分治策略的背后原理是什么. 同动态规划不同,在动态规划中,为了得到最终的答案,我们需要把一个大的问题"展开"为几个子问题("expand" the solutions of sub-problems),但是在这里,我们会更多的谈到如何把一些子解决方案组合到一起.对于一般问题,他们的子问题的解决方案是对等的,他们

java实现任意四则运算表达式求值算法_C 语言

本文实例讲述了java实现任意四则运算表达式求值算法.分享给大家供大家参考.具体分析如下: 该程序用于计算任意四则运算表达式.如 4 * ( 10 + 2 ) + 1 的结果应该为 49. 算法说明: 1. 首先定义运算符优先级.我们用一个 Map<String, Map<String, String>> 来保存优先级表.这样我们就可以通过下面的方式来计算两个运算符的优先级了: /** * 查表得到op1和op2的优先级 * @param op1 运算符1 * @param op2

Java实现DES加解密算法解析_java

本文实例讲述了Java实现DES加解密算法解析.分享给大家供大家参考,具体如下:   简介: 数据加密算法(Data Encryption Algorithm,DEA)是一种对称加密算法,很可能是使用最广泛的密钥系统,特别是在保护金融数据的安全中,最初开发的DEA是嵌入硬件中的.通常,自动取款机(Automated Teller Machine,ATM)都使用DEA.它出自IBM的研究工作,IBM也曾对它拥有几年的专利权,但是在1983年已到期后,处于公有范围中,允许在特定条件下可以免除专利使用

Java实现的各种排序算法(插入排序、选择排序算法、冒泡排序算法)_java

一.插入排序算法实现java版本 public static int[] insert_sort(int[] a) { for (int i = 0; i < a.length; i++) { for(int j=i+1;j>0&&j<a.length;j--) { if(a[j]<a[j-1]) { int tmp = a[j]; //这样定义初始化逻辑上是可以的,j变量,每次tmp的值变化的 a[j] = a[j-1]; a[j-1] = tmp; } } }

图文讲解Java中实现quickSort快速排序算法的方法_java

相对冒泡排序.选择排序等算法而言,快速排序的具体算法原理及实现有一定的难度.为了更好地理解快速排序,我们仍然以举例说明的形式来详细描述快速排序的算法原理.在前面的排序算法中,我们以5名运动员的身高排序问题为例进行讲解,为了更好地体现快速排序的特点,这里我们再额外添加3名运动员.实例中的8名运动员及其身高信息详细如下(F.G.H为新增的运动员): A(181).B(169).C(187).D(172).E(163).F(191).G(189).H(182) 在前面的排序算法中,这些排序都是由教练主