문제
요약
- 위와 같은 도형 5개가 있다. 도형의 모든 칸이 격자 안에 올라가게 놓고, 도형에 포개진 숫자들을 더한 값의 최댓값을 구한다.
분류
- 구현
- 브루트포스
풀이
1. 브루트포스
- 케이스가 21개밖에 안 된다. 그냥 구현한다.
- 약간의 포인트는 반복되는 것들을 입력해야할 때 공통되는 것을 먼저 쓰고 복붙하는 것도 좋지만, 전체를 쓰고 각각의 케이스에서 없애야할 부분을 지워내는 것도 방법이다.
- sys.stdin.readline 필수!
import sys
input = sys.stdin.readline
n, m = map(int, input().split())
arr = [
list(map(int, input().split())) for _ in range(n)
]
answer = 0
for r in range(n):
for c in range(m):
temp = []
# ㅡ
if c + 3 < m:
temp.append(sum(arr[r][c:c+4]))
# |
if r + 3 < n:
temp.append(sum([val[c] for val in arr[r:r+4]]))
# ㅁ
if r + 1 < n and c + 1 < m:
temp.append(sum(
[arr[i][j] for i in range(r, r + 2) for j in range(c, c + 2)]))
# 2 3
if r + 1 < n and c + 2 < m:
temp += [
sum([arr[r][c+1], arr[r+1][c], arr[r+1][c+1], arr[r+1][c+2]]),
sum([arr[r][c], arr[r][c+1], arr[r][c+2], arr[r+1][c+1]]),
sum([arr[r][c+1], arr[r][c+2], arr[r+1][c], arr[r+1][c+1]]),
sum([arr[r][c], arr[r][c+1], arr[r+1][c+1], arr[r+1][c+2]]),
sum([arr[r][c], arr[r][c+1], arr[r][c+2], arr[r+1][c+2]]),
sum([arr[r][c], arr[r+1][c], arr[r+1][c+1], arr[r+1][c+2]]),
sum([arr[r][c], arr[r][c+1], arr[r][c+2], arr[r+1][c]]),
sum([arr[r][c+2], arr[r+1][c], arr[r+1][c+1], arr[r+1][c+2]]),
]
# 3 2
if r + 2 < n and c + 1 < m:
temp += [
sum([arr[r][c], arr[r+1][c], arr[r+1][c+1], arr[r+2][c]]),
sum([arr[r][c+1], arr[r+1][c], arr[r+1][c+1], arr[r+2][c+1]]),
sum([arr[r][c+1], arr[r+1][c], arr[r+1][c+1], arr[r+2][c]]),
sum([arr[r][c], arr[r+1][c], arr[r+1][c+1], arr[r+2][c+1]]),
sum([arr[r][c], arr[r][c+1], arr[r+1][c], arr[r+2][c]]),
sum([arr[r][c+1], arr[r+1][c+1], arr[r+2][c], arr[r+2][c+1]]),
sum([arr[r][c], arr[r+1][c], arr[r+2][c], arr[r+2][c+1]]),
sum([arr[r][c], arr[r][c+1], arr[r+1][c+1], arr[r+2][c+1]]),
]
answer = max([answer, *temp])
print(answer)
2. DFS
input = sys.stdin.readline
dx = [-1, 0, 1, 0]
dy = [0, 1, 0, -1]
def dfs(r, c, idx, total):
global answer
if answer >= total + max_val * (3-idx):
return
if idx == 3:
answer = max(answer, total)
return
else:
for i in range(4):
nr = r + dx[i]
nc = c + dy[i]
if 0 <= nr < N and 0 <= nc < M and visit[nr][nc] == 0:
if idx == 1:
visit[nr][nc] = 1
dfs(r, c, idx+1, total+arr[nr][nc])
visit[nr][nc] = 0
visit[nr][nc] = 1
dfs(nr, nc, idx+1, total+arr[nr][nc])
visit[nr][nc] = 0
N, M = map(int, input().split())
arr = [list(map(int, input().split())) for _ in range(N)]
visit = [([0]*M) for _ in range(N)]
answer = 0
max_val = max(map(max, arr))
for r in range(N):
for c in range(M):
visit[r][c] = 1
dfs(r, c, 0, arr[r][c])
visit[r][c] = 0
print(answer)