Question
I am writing code for an external merge sort. The idea is that the input file contains too many numbers to hold in an array, so you read part of it at a time and store the intermediate results in files. Here's my code. It runs fast, but not fast enough. Can you think of any improvements I could make? Note that I first sort every 1M integers together, so I can skip the early iterations of the merging algorithm.
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.security.DigestInputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
public class ExternalSort {

    public static void sort(String f1, String f2) throws Exception {
        RandomAccessFile raf1 = new RandomAccessFile(f1, "rw");
        RandomAccessFile raf2 = new RandomAccessFile(f2, "rw");
        int fileByteSize = (int) (raf1.length() / 4);
        int size = Math.min(1000000, fileByteSize);
        externalSort(f1, f2, size);
        boolean writeToOriginal = true;
        DataOutputStream dos;
        while (size <= fileByteSize) {
            if (writeToOriginal) {
                raf1.seek(0);
                dos = new DataOutputStream(new BufferedOutputStream(
                        new MyFileOutputStream(raf1.getFD())));
            } else {
                raf2.seek(0);
                dos = new DataOutputStream(new BufferedOutputStream(
                        new MyFileOutputStream(raf2.getFD())));
            }
            for (int i = 0; i < fileByteSize; i += 2 * size) {
                if (writeToOriginal) {
                    dos = merge(f2, dos, i, size);
                } else {
                    dos = merge(f1, dos, i, size);
                }
            }
            dos.flush();
            writeToOriginal = !writeToOriginal;
            size *= 2;
        }
        if (writeToOriginal) {
            raf1.seek(0);
            raf2.seek(0);
            dos = new DataOutputStream(new BufferedOutputStream(
                    new MyFileOutputStream(raf1.getFD())));
            int i = 0;
            while (i < raf2.length() / 4) {
                dos.writeInt(raf2.readInt());
                i++;
            }
            dos.flush();
        }
    }

    public static void externalSort(String f1, String f2, int size) throws Exception {
        RandomAccessFile raf1 = new RandomAccessFile(f1, "rw");
        RandomAccessFile raf2 = new RandomAccessFile(f2, "rw");
        int fileByteSize = (int) (raf1.length() / 4);
        int[] array = new int[size];
        DataInputStream dis = new DataInputStream(new BufferedInputStream(
                new MyFileInputStream(raf1.getFD())));
        DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(
                new MyFileOutputStream(raf2.getFD())));
        int count = 0;
        while (count < fileByteSize) {
            for (int k = 0; k < size; ++k) {
                array[k] = dis.readInt();
            }
            count += size;
            Arrays.sort(array);
            for (int k = 0; k < size; ++k) {
                dos.writeInt(array[k]);
            }
        }
        dos.flush();
        raf1.close();
        raf2.close();
        dis.close();
        dos.close();
    }

    public static DataOutputStream merge(String file,
            DataOutputStream dos, int start, int size) throws IOException {
        RandomAccessFile raf = new RandomAccessFile(file, "rw");
        RandomAccessFile raf2 = new RandomAccessFile(file, "rw");
        int fileByteSize = (int) (raf.length() / 4);
        raf.seek(4 * start);
        raf2.seek(4 * start);
        DataInputStream dis = new DataInputStream(new BufferedInputStream(
                new MyFileInputStream(raf.getFD())));
        DataInputStream dis3 = new DataInputStream(new BufferedInputStream(
                new MyFileInputStream(raf2.getFD())));
        int i = 0;
        int j = 0;
        int max = size * 2;
        int a = dis.readInt();
        int b;
        if (start + size < fileByteSize) {
            dis3.skip(4 * size);
            b = dis3.readInt();
        } else {
            b = Integer.MAX_VALUE;
            j = size;
        }
        while (i + j < max) {
            if (j == size || (a <= b && i != size)) {
                dos.writeInt(a);
                i++;
                if (start + i == fileByteSize) {
                    i = size;
                } else if (i != size) {
                    a = dis.readInt();
                }
            } else {
                dos.writeInt(b);
                j++;
                if (start + size + j == fileByteSize) {
                    j = size;
                } else if (j != size) {
                    b = dis3.readInt();
                }
            }
        }
        raf.close();
        raf2.close();
        return dos;
    }

    public static void main(String[] args) throws Exception {
        String f1 = args[0];
        String f2 = args[1];
        sort(f1, f2);
    }
}
Answer 1:
You might wish to merge k > 2 segments at a time. This reduces the amount of I/O from n log n / log 2 to n log n / log k, because each merge pass reads and writes the whole file and a k-way merge needs fewer passes.
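To put rough numbers on that, here is a small sketch that counts merge passes for a given fan-in k. The class and method names are made up for illustration, and the figure of 1,000 initial runs is an assumption (about 10^9 integers pre-sorted in chunks of 10^6, as in the question's first phase).

```java
// Illustrative helper (hypothetical names): counts how many merge passes are
// needed to reduce `runs` sorted runs to a single run when every pass merges
// groups of k runs. Each pass reads and writes the whole file once.
final class MergePasses {
    static int passes(long runs, int k) {
        if (k < 2) {
            throw new IllegalArgumentException("k must be at least 2");
        }
        int passes = 0;
        while (runs > 1) {
            runs = (runs + k - 1) / k; // ceil(runs / k) runs remain after a pass
            passes++;
        }
        return passes;
    }

    public static void main(String[] args) {
        long runs = 1000; // assumed: 10^9 ints split into sorted runs of 10^6
        System.out.println("k = 2:  " + passes(runs, 2));  // 10 passes
        System.out.println("k = 16: " + passes(runs, 16)); // 3 passes
    }
}
```

Under these assumptions, going from 2-way to 16-way merging cuts the number of full passes over the file from 10 to 3.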
Edit: In pseudocode, this would look something like this:
void sort(List list) {
    if (list fits in memory) {
        list.sort();
    } else {
        sublists = partition list into k about equally big sublists
        for (sublist : sublists) {
            sort(sublist);
        }
        merge(sublists);
    }
}

void merge(List[] sortedsublists) {
    keep a pointer in each sublist, which initially points to its first element
    do {
        find the pointer pointing at the smallest element
        add the element it points to to the result list
        advance that pointer
    } until all pointers have reached the end of their sublist
    return the result list
}
To efficiently find the "smallest" pointer, you might employ a PriorityQueue.
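The merge step above could be sketched in Java roughly as follows. This is only an illustration, not the asker's code: `KWayMerge`, `Head`, and the use of `PrimitiveIterator.OfInt` as a stand-in for a sorted run are assumptions made to keep the example self-contained. In a real external sort each run would be a buffered reader over a segment of the file, and the sink would be a buffered writer.

```java
import java.util.Arrays;
import java.util.List;
import java.util.PrimitiveIterator;
import java.util.PriorityQueue;
import java.util.function.IntConsumer;

// Hypothetical sketch of a k-way merge driven by a PriorityQueue.
final class KWayMerge {
    // The "pointer" into one run: its current smallest value plus its source.
    private static final class Head {
        int value;
        final PrimitiveIterator.OfInt source;

        Head(PrimitiveIterator.OfInt source) {
            this.source = source;
            this.value = source.nextInt();
        }
    }

    // Merges already-sorted runs and passes the values to `sink` in ascending order.
    static void merge(List<PrimitiveIterator.OfInt> runs, IntConsumer sink) {
        PriorityQueue<Head> heap =
                new PriorityQueue<>((a, b) -> Integer.compare(a.value, b.value));
        for (PrimitiveIterator.OfInt run : runs) {
            if (run.hasNext()) {
                heap.add(new Head(run)); // empty runs are skipped
            }
        }
        while (!heap.isEmpty()) {
            Head smallest = heap.poll();      // pointer at the smallest element
            sink.accept(smallest.value);      // add it to the result
            if (smallest.source.hasNext()) {  // advance that pointer
                smallest.value = smallest.source.nextInt();
                heap.add(smallest);
            }
        }
    }

    public static void main(String[] args) {
        List<PrimitiveIterator.OfInt> runs = Arrays.asList(
                Arrays.stream(new int[] {1, 4, 9}).iterator(),
                Arrays.stream(new int[] {2, 3, 10}).iterator(),
                Arrays.stream(new int[] {-5, 7}).iterator());
        StringBuilder out = new StringBuilder();
        merge(runs, v -> out.append(v).append(' '));
        System.out.println(out.toString().trim()); // -5 1 2 3 4 7 9 10
    }
}
```

Each element costs one poll and at most one add, so merging n elements from k runs takes O(n log k) comparisons while touching every element only once per pass.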
Answer 2:
I would use memory-mapped files. They can be as much as 10x faster than this type of I/O, and I suspect they will be much faster in this case as well. Mapped buffers use virtual memory rather than heap space to store data, and they can be larger than your available physical memory.
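As a minimal sketch of what that could look like for the first phase (sorting fixed-size runs in place), assuming the file holds big-endian 4-byte integers as written by DataOutputStream. `MappedRunSorter` and `sortRun` are hypothetical names, and note that a single mapping is limited to Integer.MAX_VALUE bytes, so a large file has to be mapped in pieces.

```java
import java.io.IOException;
import java.nio.IntBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Arrays;

// Hypothetical sketch: sort one run of a file in place through a mapped region.
final class MappedRunSorter {
    // Sorts `count` ints beginning at int index `start` of `file`.
    // Assumes count * 4 fits in a single mapping (at most Integer.MAX_VALUE bytes).
    static void sortRun(Path file, long start, int count) throws IOException {
        try (FileChannel ch = FileChannel.open(file,
                StandardOpenOption.READ, StandardOpenOption.WRITE)) {
            MappedByteBuffer map =
                    ch.map(FileChannel.MapMode.READ_WRITE, start * 4L, count * 4L);
            // Byte buffers are big-endian by default, matching DataOutputStream.
            IntBuffer ints = map.asIntBuffer();
            int[] run = new int[count];
            ints.get(run);     // bulk copy from the mapping into the heap array
            Arrays.sort(run);
            ints.rewind();
            ints.put(run);     // bulk copy the sorted run back into the mapping
            map.force();       // ask the OS to write the changes to the file
        }
    }
}
```

A call such as `MappedRunSorter.sortRun(path, 0, 1_000_000)` would sort the first million integers of the file. The same IntBuffer view can also be read element by element during the merge phase instead of going through DataInputStream.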
Answer 3:
We have implemented a public domain external sort in Java:
http://code.google.com/p/externalsortinginjava/
It might be faster than yours. We use strings and not integers, but you could easily modify our code by substituting integers for strings (the code was made hackable by design). At the very least, you can compare with our design.
Looking at your code, it seems that you are reading the data one integer at a time, so I would guess that I/O is the bottleneck. With external memory algorithms, you want to read and write whole blocks of data, especially in Java.
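One way to move whole blocks, sketched here with java.nio rather than taken from the linked project: read bytes into a ByteBuffer and bulk-copy them into an int[] through an IntBuffer view. `BlockIO`, `readInts`, and `writeInts` are hypothetical names, and the data is assumed to be big-endian 4-byte integers.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;

// Hypothetical sketch of block-wise integer I/O over a FileChannel.
final class BlockIO {
    // Reads up to min(dest.length, buf.capacity() / 4) ints from the channel's
    // current position into dest and returns how many were read (0 at end of file).
    // Trailing bytes that do not form a whole int are ignored.
    static int readInts(FileChannel ch, ByteBuffer buf, int[] dest) throws IOException {
        int want = Math.min(buf.capacity() / 4, dest.length);
        buf.clear();
        buf.limit(want * 4);
        while (buf.hasRemaining() && ch.read(buf) != -1) {
            // keep reading until the block is full or the file ends
        }
        buf.flip();
        int n = buf.remaining() / 4;
        buf.asIntBuffer().get(dest, 0, n); // one bulk copy instead of n readInt() calls
        return n;
    }

    // Writes the first n ints of src at the channel's current position.
    // Requires n * 4 <= buf.capacity().
    static void writeInts(FileChannel ch, ByteBuffer buf, int[] src, int n) throws IOException {
        buf.clear();
        buf.asIntBuffer().put(src, 0, n); // the view has its own position, buf stays at 0
        buf.limit(n * 4);
        while (buf.hasRemaining()) {
            ch.write(buf);
        }
    }
}
```

The caller would typically allocate one reasonably large buffer up front, for example `ByteBuffer.allocateDirect(1 << 16)`, and reuse it for every block. The best block size depends on the machine and is worth measuring.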
Answer 4:
You are sorting integers, so you should check out radix sort. The core idea of radix sort is that you can sort n-byte integers in n passes through the data using radix 256.
You can combine this with merge sort theory.
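For illustration, here is a minimal in-memory sketch of a least-significant-digit radix sort for 4-byte Java ints with radix 256. The class name `RadixSort` is hypothetical, and using it in place of Arrays.sort when creating the initial sorted runs is just one possible reading of "combining" it with merge sort. The sign bit is flipped while bucketing so that negative values sort before positive ones.

```java
// Hypothetical sketch: LSD radix sort of signed 32-bit ints, one byte per pass.
final class RadixSort {
    static void sort(int[] a) {
        int[] src = a;
        int[] dst = new int[a.length];
        for (int shift = 0; shift < 32; shift += 8) {
            int[] count = new int[257];
            // Histogram of the current byte; XOR with 0x80000000 flips the sign
            // bit so that two's-complement order matches unsigned byte order.
            for (int v : src) {
                count[(((v ^ 0x80000000) >>> shift) & 0xFF) + 1]++;
            }
            // Prefix sums: count[b] becomes the first output index of bucket b.
            for (int i = 0; i < 256; i++) {
                count[i + 1] += count[i];
            }
            // Stable scatter into the other array.
            for (int v : src) {
                dst[count[((v ^ 0x80000000) >>> shift) & 0xFF]++] = v;
            }
            int[] tmp = src;
            src = dst;
            dst = tmp;
        }
        // Four passes means an even number of swaps, so the sorted data is back in a.
    }
}
```

This takes four linear passes and one extra array of the same size, so the run size that fits in memory is roughly half of what Arrays.sort would allow.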
Source: https://stackoverflow.com/questions/8402106/how-to-speed-up-external-merge-sort-in-java