benchmark1.py 3.85 KB
from functools import partial
import gc
import time
import sys

from py4j.java_gateway import JavaGateway, CallbackServerParameters


# Loop count for benchmarks whose single call is expensive; main() uses it
# for the benchmark that sends a 10 * 1024 * 1024-character string to Java.
ITERATIONS_FOR_LENGTHY_METHOD = 3


class ComparablePython(object):
    """Python object exposed to Java as a ``java.lang.Comparable``.

    py4j reads the nested ``Java.implements`` list to decide which Java
    interfaces this object implements, so Java code (e.g.
    ``Collections.sort``) can call :meth:`compareTo` on it through the
    callback server.
    """

    def __init__(self, value):
        # Number used for ordering; compareTo subtracts these values, so
        # presumably an int (Java expects an int result) - confirm at callers.
        self.value = value

    def compareTo(self, obj):
        """Return ``self.value - obj.value`` (or ``self.value`` if obj is None).

        Hack: calling ``compareTo(None)`` returns the raw value. This is how one
        instance reads another's value when ``obj`` arrives as a proxy
        (presumably a py4j callback proxy) rather than a plain Python object.
        """
        if obj is not None:
            other_value = obj.compareTo(None)
            return self.value - other_value
        return self.value

    class Java:
        implements = ["java.lang.Comparable"]


def callStaticMethodNoParam(iterations, staticMethod):
    """Call the zero-argument ``staticMethod`` ``iterations`` times.

    Returns the value produced by the final call (0 if it was never called).
    Keeping the last result alive keeps its value in use, rather than
    discarding it.
    """
    last_value = 0
    done = 0
    while done < iterations:
        done += 1
        last_value = staticMethod()
    return last_value


def callInstanceMethodWithShortParam(iterations, instanceMethod):
    """Call ``instanceMethod`` with a short string and then with ``1``.

    Each of the ``iterations`` rounds makes two single-argument calls, so the
    benchmark mixes a string argument and an int argument. Returns None.
    """
    text_arg = "Super Long Param"
    rounds = 0
    while rounds < iterations:
        for arg in (text_arg, 1):
            instanceMethod(arg)
        rounds += 1


def callFunc(iterations, func):
    """Invoke the zero-argument ``func`` ``iterations`` times.

    Returns whatever the last invocation returned, or None when ``func`` was
    never called.
    """
    outcome = None
    call_index = 0
    while call_index < iterations:
        outcome, call_index = func(), call_index + 1
    return outcome


def benchmark(name, func):
    """Time one call of ``func`` and print ``"<seconds> - <name>"``.

    After timing, a full garbage collection runs. Any garbage left by this
    benchmark is therefore reclaimed before the caller starts the next timing.

    Args:
        name: Label printed after the elapsed time.
        func: Zero-argument callable to measure.

    Returns:
        Elapsed time in seconds. The original returned None, so this is
        backward-compatible.
    """
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # time and can jump if the system clock is adjusted mid-measurement.
    # Fall back to time.time() on interpreters that lack perf_counter.
    clock = getattr(time, "perf_counter", time.time)
    start = clock()
    func()
    elapsed = clock() - start
    print("{0} - {1}".format(elapsed, name))
    gc.collect()
    return elapsed


def main(iterations):
    """Run all py4j round-trip benchmarks and print their timings.

    Connects to an already-running Java gateway (see the ``__main__`` section
    for how to start it). The callback server is enabled so Java can call back
    into Python objects. Each benchmark is run, then the gateway is shut down.

    Args:
        iterations: Loop count for the cheap per-call benchmarks. The
            collection and callback benchmarks use one tenth of it (integer
            division) when it is greater than 10.
    """
    # Floor division keeps the count an integer on Python 3. There,
    # ``iterations / 10`` is true division: 15 / 10 == 1.5 would make callFunc
    # loop twice instead of once.
    small_iterations = iterations // 10 if iterations > 10 else iterations
    gateway = JavaGateway(
        callback_server_parameters=CallbackServerParameters())
    try:
        currentTimeMillis = gateway.jvm.java.lang.System.currentTimeMillis
        sb = gateway.jvm.java.lang.StringBuilder()
        append = sb.append
        sb2 = gateway.jvm.java.lang.StringBuilder()

        def reflection():
            # Same Java method called with an int, then with a String; the
            # name suggests this measures py4j's overload resolution.
            sb2.append(2)
            sb2.append("hello")

        def constructorAndMemoryManagement():
            # A new Java object per call; its proxy becomes garbage on return.
            local_sb = gateway.jvm.java.lang.StringBuilder("Hello World")
            local_sb.append("testing")

        def javaCollection():
            al = gateway.jvm.java.util.ArrayList()
            al.append("test")
            al.append(1)
            al.append(True)
            len(al)  # Result discarded on purpose: only the call is measured.
            result = []
            for elem in al:
                result.append(elem)
            return result

        def callBack():
            # Sorting on the Java side makes Java call
            # ComparablePython.compareTo back through the callback server.
            al = gateway.jvm.java.util.ArrayList()
            cp10 = ComparablePython(10)
            cp1 = ComparablePython(1)
            cp5 = ComparablePython(5)
            cp7 = ComparablePython(7)
            al.append(cp10)
            al.append(cp1)
            al.append(cp5)
            al.append(cp7)
            gateway.jvm.java.util.Collections.sort(al)

        def longParamCall():
            # Sends a string of 10 * 1024 * 1024 characters to Java each call.
            longParam = "s" * 1024 * 1024 * 10
            long_sb = gateway.jvm.java.lang.StringBuilder()
            long_sb.append(longParam)
            long_sb.toString()

        benchmark(
            "callStaticMethodNoParam",
            partial(callStaticMethodNoParam, iterations, currentTimeMillis))
        benchmark(
            "callInstanceMethodWithShortParam",
            partial(callInstanceMethodWithShortParam, iterations, append))
        benchmark(
            "callWithReflection",
            partial(callFunc, iterations, reflection))
        benchmark(
            "constructorAndMemoryManagement",
            partial(callFunc, iterations, constructorAndMemoryManagement))
        benchmark(
            "longParamAndMemoryManagement",
            partial(callFunc, ITERATIONS_FOR_LENGTHY_METHOD, longParamCall))
        benchmark(
            "javaCollection",
            partial(callFunc, small_iterations, javaCollection))
        benchmark(
            "callBack",
            partial(callFunc, small_iterations, callBack))
    finally:
        # Shut the gateway (and its callback server) down even when a
        # benchmark raises. Otherwise the connection / callback server may be
        # left running and keep the process alive.
        gateway.shutdown()


if __name__ == "__main__":
    # How to run:
    # 1. Start the Java side (py4j-java), e.g.:
    #      cd py4j-java; ./gradlew testsJar;
    #      java -Xmx4096m -cp build/libs/py4j-tests-0.10.0.jar \
    #      py4j.example.ExampleApplication
    # 2. Run this script from py4j-python:
    #      cd py4j-python; export PYTHONPATH=src
    #      python3 src/py4j/tests/benchmark1.py [iterations]
    # The optional first command-line argument replaces the default count of
    # 100000 iterations.
    iterations = int(sys.argv[1]) if len(sys.argv) > 1 else 100000
    main(iterations)