Koder / 박성훈
article thumbnail

처음으로 메모리 초과가 나온 문제.

이제부턴 메모리도 신경써서 풀어야겠다.

https://www.acmicpc.net/problem/2096

 

2096번: 내려가기

첫째 줄에 N(1 ≤ N ≤ 100,000)이 주어진다. 다음 N개의 줄에는 숫자가 세 개씩 주어진다. 숫자는 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 중의 하나가 된다.

www.acmicpc.net

처음에는 배열을 maxdp를 50만정도, mindp를 50만, 입력도 50만으로

엄청많이 만들어서 메모리 초과가 나왔으나,

어디선가 들었던 기법을 통해 메모리를 대폭 절감해 통과할 수 있었던 문제이다.

#include <stdio.h>

#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#define MAX(a, b) (((a) > (b)) ? (a) : (b))

int n,i,j;
int maxdp[2][5] = {0};
int mindp[2][5] = {0};
int input[100001][5] = {0};
int maxsum = -1;
int minsum = 1234567;


int main(){
	scanf("%d", &n);
	for(i=0; i<n; i++) for(int j=1; j<=3; j++) scanf("%d", &input[i][j]);
	
	for(i=1; i<=3; i++){
		maxdp[0][i] = input[0][i];
		mindp[0][i] = input[0][i];
	}
	
	for(i=0; i<n; i++){
		mindp[i%2][0] = mindp[i%2][4] = 1234567;
		for(j=1; j<=3; j++){
			maxdp[(i+1)%2][j] = MAX(MAX(maxdp[i%2][j+1], maxdp[i%2][j-1]), maxdp[i%2][j]) + input[i+1][j];
			mindp[(i+1)%2][j] = MIN(MIN(mindp[i%2][j+1], mindp[i%2][j-1]), mindp[i%2][j]) + input[i+1][j];
		}
	}
	for(i=1; i<=3; i++){
		maxsum = MAX(maxsum, maxdp[n%2][i]);
		minsum = MIN(minsum, mindp[n%2][i]);
	}
	printf("%d %d", maxsum, minsum);
	return 0;
}

나름 메모리 아끼려고 원래 max랑 min은 함수로 만드는데, define으로 바꿔보았다.

메모리 절약 효과가 있었는지는 잘 모르겠다. 그냥 감이랄까?

입력을 다 받고, dp배열들의 첫번째 줄은 그냥 입력값을 그대로 넣어주었다.

포문부분이 좀 중요한데,

보면 인덱스에 전부 %2가 되있는 모습을 볼 수 있을것이다.

얘를 빼고 보자면

for(i=0; i<n; i++){
	mindp[i][0] = mindp[i][4] = 1234567;
	for(j=1; j<=3; j++){
		maxdp[i+1][j] = MAX(MAX(maxdp[i][j+1], maxdp[i][j-1]), maxdp[i][j]) + input[i+1][j];
		mindp[i+1][j] = MIN(MIN(mindp[i][j+1], mindp[i][j-1]), mindp[i][j]) + input[i+1][j];
	}
}

평범한 dp문제가 된다.

maxdp[i][j]의 값은 i번째로 수를 선택할때 j번 인덱스의 값을 고르면,

그 때의 최대점수이고,

mindp[i][j]의 값은 최소점수를 의미한다.

j번 인덱스의 값을 고르는 부분이 뒤에 붙어있는 input[i+1][j]이다.

그런데 여기서 항상 계산하는 값은 두개뿐이라는걸 알 수 있다.

i번째 인덱스와 i+1번째 인덱스.

이 두개를 위해서 10만짜리 배열을 만들어서 메모리 초과가 나왔던 것이다.

i를 2로 나눴을때의 나머지와 i+1을 2로 나눴을때의 나머지는 항상 다르므로

(하나가 짝수이면, 나머지 하나는 홀수로 서로 다름)

% 연산을 통해 계산에 필요한 부분만 남기고 메모리를 아낄 수 있었다.

for(i=1; i<=3; i++){
	maxsum = MAX(maxsum, maxdp[n%2][i]);
	minsum = MIN(minsum, mindp[n%2][i]);
}

그리고 마지막으로 출력해주는 값 maxsum,minsum 두개에

dp[n][k] (1 <= k <= 3) 범위의 최대값을 넣어준다.

이 역시 %2연산을 해줘야 제대로 값이 들어갈 수 있다.

이 두값을 출력하면 AC.

+

이제 나도 골드!

 

[이 글은 작성 시점이 2020년 8월 11일입니다]

반응형