본문 바로가기

[BOJ] - JAVA

[백준] 2740 : 행렬 곱셈 JAVA 풀이

 

분할정복 문제를 조금 풀어봤는데

아직도 어떻게 풀어야 할지 감이 안 와서 검색을 해보니 

단순히 행렬곱셈을 하는 코드도 정답처리된다기에 우선 그렇게 풀어보았다.

 

import java.io.*;
import java.util.*;

public class Main{
    static int[][] A;
    static int[][] B;
    static int[][] C;
    
    public static void main(String[] args)throws IOException{
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine(), " ");
        
        int N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken());
        
        A = new int[N][M];
        for(int i=0;i<N;i++){
	        st = new StringTokenizer(br.readLine(), " ");
	        for(int j=0;j<M;j++){
	         A[i][j] = Integer.parseInt(st.nextToken());
            }
        }
        st = new StringTokenizer(br.readLine(), " ");
        
        M = Integer.parseInt(st.nextToken());
        int K = Integer.parseInt(st.nextToken());
        
        B = new int[M][K];
        C = new int[N][K];
        
        for(int i=0;i<M;i++){
	        st = new StringTokenizer(br.readLine(), " ");
	        for(int j=0;j<K;j++){
	         B[i][j] = Integer.parseInt(st.nextToken());
            }
        }
        
        for(int i=0;i<N;i++){
            for(int j=0;j<K;j++){
                
                for(int k=0;k<M;k++){
                    C[i][j] += A[i][k]*B[k][j];
                }
            }
        }
        
        for(int i=0;i<N;i++){
            for(int j=0;j<K;j++){
                System.out.print(C[i][j]+" ");
            }
            System.out.println();
        }
        
    }
}

 

그리고 아래 코드는

https://st-lab.tistory.com/245

 

[백준] 2740번 : 행렬 곱셈 - JAVA [자바]

www.acmicpc.net/problem/2740 2740번: 행렬 곱셈 첫째 줄에 행렬 A의 크기 N 과 M이 주어진다. 둘째 줄부터 N개의 줄에 행렬 A의 원소 M개가 순서대로 주어진다. 그 다음 줄에는 행렬 B의 크기 M과 K가 주어진

st-lab.tistory.com

 

이 포스트를 참고해 스트라센 알고리즘을 적용해서 풀어본 코드이다.

그런데..이 방식으로는 메모리 초과가 나온다.

하지만 분할정복으로 풀어봤다는 점에 의의를 둔다.

 

import java.io.*;
import java.util.*;

public class bj_2740 {
	public static void main(String[] args) throws IOException {
		// TODO Auto-generated method stub
		
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine(), " ");
        
        int N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken());
        
        int[][] A = new int[128][128];
        for(int i=0;i<N;i++){
	        st = new StringTokenizer(br.readLine(), " ");
	        for(int j=0;j<M;j++){
	         A[i][j] = Integer.parseInt(st.nextToken());
            }
        }
        st = new StringTokenizer(br.readLine(), " ");
        M = Integer.parseInt(st.nextToken());
        int K = Integer.parseInt(st.nextToken());
        
        int[][] B = new int[128][128];
        
        for(int i=0;i<M;i++){
	        st = new StringTokenizer(br.readLine(), " ");
	        for(int j=0;j<K;j++){
	         B[i][j] = Integer.parseInt(st.nextToken());
            }
        }
        
        int big = Math.max(Math.max(M, N), K);
        
        int size = 1;
        while(true) {
        	if(size>=big) {
        		break;
        	}
        	size *= 2;
        }
        
        int[][] C = multiply(A,B,size);
        
        StringBuilder sb = new StringBuilder();
        
        for(int i=0;i<N;i++) {
        	for(int j=0;j<K;j++) {
        		sb.append(C[i][j]+" ");
        	}
        	sb.append('\n');
        }
        System.out.print(sb);
	}
	
	static int[][] multiply(int[][] A, int[][] B, int size){
		
		int[][] C = new int[size][size];
		
		if(size==1) {
			C[0][0] = A[0][0]*B[0][0];
			return C;
		}
		
		int newSize = size/2;
		
		int[][] a11 = subArray(A, 0, 0, newSize);
		int[][] a12 = subArray(A, 0, newSize, newSize);
		int[][] a21 = subArray(A, newSize, 0, newSize);
		int[][] a22 = subArray(A, newSize, newSize, newSize);
		
		int[][] b11 = subArray(B, 0, 0, newSize);
		int[][] b12 = subArray(B, 0, newSize, newSize);
		int[][] b21 = subArray(B, newSize, 0, newSize);
		int[][] b22 = subArray(B, newSize, newSize, newSize);
		
		int[][] M1 = multiply(add(a11, a22,newSize),add(b11, b22,newSize),newSize);
		
		int[][] M2 = multiply(add(a21, a22,newSize),b11,newSize);
		
		int[][] M3 = multiply(a11,sub(b12, b22,newSize),newSize);
		
		int[][] M4 = multiply(a22,sub(b21, b11,newSize),newSize);
		
		int[][] M5 = multiply(add(a11, a12,newSize),b22,newSize);
		
		int[][] M6 = multiply(sub(a21, a11,newSize),add(b11, b12,newSize),newSize);
		
		int[][] M7 = multiply(sub(a12, a22,newSize),add(b21, b22,newSize),newSize);
		
		int[][] c11 = sub(add(add(M1, M4,newSize),M7,newSize),M5,newSize);
		int[][] c12 = add(M3, M5,newSize);
		int[][] c21 = add(M2, M4,newSize);
		int[][] c22 = sub(add(add(M1, M3,newSize),M6,newSize),M2,newSize);
		
		
		merge(c11, C, 0, 0, newSize);
		merge(c12, C, 0, newSize, newSize);
		merge(c21, C, newSize, 0, newSize);
		merge(c22, C, newSize, newSize, newSize);
		
		return C;
	}
	
	static void merge(int[][] src, int[][] C, int row, int col, int size) {
		
		for(int i=0;i<size;i++) {
			for(int j=0;j<size;j++) {
				C[row+i][col+j] = src[i][j];
			}
		}
	}
	
	static int[][] add(int[][] A, int[][] B, int size){
		
		int[][] dest = new int[size][size];
		
		for(int i=0;i<size;i++) {
			for(int j=0;j<size;j++) {
				dest[i][j] = A[i][j]+B[i][j];
			}
		}
		
		return dest;
	}
	
	static int[][] sub(int[][] A, int[][] B, int size){
		
		int[][] dest = new int[size][size];
		
		for(int i=0;i<size;i++) {
			for(int j=0;j<size;j++) {
				dest[i][j] = A[i][j]-B[i][j];
			}
		}
		
		return dest;
	}
	
	static int[][] subArray(int[][] src, int row, int col, int size){
		
		int[][] dest = new int[size][size];
		for(int dest_i = 0, src_i = row; dest_i<size;dest_i++, src_i++) {
			for(int dest_j = 0, src_j = col; dest_j<size;dest_j++, src_j++) {
				dest[dest_i][dest_j] = src[src_i][src_j];
			}
		}
		return dest;
	}

}