PySparkでコントロールブレイク処理

お題は次のエントリです。

gonsuke777.hatenablog.com

上記エントリではいわゆるコントロールブレイク処理(ソート済みのレコードを読み込み、キー項目ごとにグループ分けして行う処理のことでキーブレイク処理と呼ぶことも)を 1 本の SQL でスマートに行っています。これと同じことを PySpark でやってみるという話です。

次のような CSV ファイルを用意しておきます。

sales_date,jan_code,sales_cnt
2014/10/06,AAA,100
2014/10/07,AAA,200
2014/10/08,BBB,100
2014/10/09,BBB,150
2014/10/10,BBB,189
2014/10/11,CCC,120
2014/10/12,CCC,111
2014/10/13,AAA,210
2014/10/14,AAA,545
2014/10/15,AAA,90
2014/10/16,CCC,90

これを Spark DataFrame に読み込みます。

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType

schema = StructType([
  StructField('sales_date', DateType()),
  StructField('jan_code', StringType()),
  StructField('sales_cnt', IntegerType())
])

df = spark.read.csv('<path-to-csv>', schema=schema, header=True, dateFormat='yyyy/MM/dd')
df.show()
# +----------+--------+---------+
# |sales_date|jan_code|sales_cnt|
# +----------+--------+---------+
# |2014-10-06|     AAA|      100|
# |2014-10-07|     AAA|      200|
# |2014-10-08|     BBB|      100|
# |2014-10-09|     BBB|      150|
# |2014-10-10|     BBB|      189|
# |2014-10-11|     CCC|      120|
# |2014-10-12|     CCC|      111|
# |2014-10-13|     AAA|      210|
# |2014-10-14|     AAA|      545|
# |2014-10-15|     AAA|       90|
# |2014-10-16|     CCC|       90|
# +----------+--------+---------+

元の SQL では ROW_NUMBER ウィンドウ関数を使って単純ソートした場合の連続値と jan_code で区切りつつソートした場合の連続値を割り振っていますが、PySpark (Spark SQL) でも pyspark.sql.functions.row_number という同じ関数があります。

from pyspark.sql import Window
from pyspark.sql.functions import row_number

# SQLでの ROW_NUMBER() OVER(ORDER BY SALES_DATE) に相当
df = df.withColumn('simple_sq', row_number().over(Window.orderBy('sales_date')))

# SQLでの ROW_NUMBER() OVER(PARTITION BY JAN_CODE ORDER BY SALES_DATE) に相当
df = df.withColumn('part_jan_sq', row_number().over(Window.partitionBy('jan_code').orderBy('sales_date')))

パーティションとソート順は pyspark.sql.Window クラスのファクトリメソッドを使って生成する pyspark.sql.WindowSpec オブジェクトとして渡します。

あとは distance を計算すれば集約カラムが作られますね。

from pyspark.sql.functions import col

df = df.withColumn('distance', col('simple_sq') - col('part_jan_sq'))
df.orderBy('sales_date').show()

# +----------+--------+---------+---------+-----------+--------+
# |sales_date|jan_code|sales_cnt|simple_sq|part_jan_sq|distance|
# +----------+--------+---------+---------+-----------+--------+
# |2014-10-06|     AAA|      100|        1|          1|       0|
# |2014-10-07|     AAA|      200|        2|          2|       0|
# |2014-10-08|     BBB|      100|        3|          1|       2|
# |2014-10-09|     BBB|      150|        4|          2|       2|
# |2014-10-10|     BBB|      189|        5|          3|       2|
# |2014-10-11|     CCC|      120|        6|          1|       5|
# |2014-10-12|     CCC|      111|        7|          2|       5|
# |2014-10-13|     AAA|      210|        8|          3|       5|
# |2014-10-14|     AAA|      545|        9|          4|       5|
# |2014-10-15|     AAA|       90|       10|          5|       5|
# |2014-10-16|     CCC|       90|       11|          3|       8|
# +----------+--------+---------+---------+-----------+--------+

集約のためのキーができたので、集約を行っておしまい。

grouped_df = df.groupBy(['jan_code', 'distance']) \
  .agg(min('sales_date').alias('sales_date_first'), \
       max('sales_date').alias('sales_date_last'), \
       sum('sales_cnt').alias('cnt_sum'))
grouped_df.orderBy('sales_date_first').show()

# +--------+--------+----------------+---------------+-------+
# |jan_code|distance|sales_date_first|sales_date_last|cnt_sum|
# +--------+--------+----------------+---------------+-------+
# |     AAA|       0|      2014-10-06|     2014-10-07|    300|
# |     BBB|       2|      2014-10-08|     2014-10-10|    439|
# |     CCC|       5|      2014-10-11|     2014-10-12|    231|
# |     AAA|       5|      2014-10-13|     2014-10-15|    845|
# |     CCC|       8|      2014-10-16|     2014-10-16|     90|
# +--------+--------+----------------+---------------+-------+