お題は次のエントリです。
上記エントリではいわゆるコントロールブレイク処理(ソート済みのレコードを読み込み、キー項目ごとにグループ分けして行う処理のことでキーブレイク処理と呼ぶことも)を 1 本の SQL でスマートに行っています。これと同じことを PySpark でやってみるという話です。
次のような CSV ファイルを用意しておきます。
sales_date,jan_code,sales_cnt 2014/10/06,AAA,100 2014/10/07,AAA,200 2014/10/08,BBB,100 2014/10/09,BBB,150 2014/10/10,BBB,189 2014/10/11,CCC,120 2014/10/12,CCC,111 2014/10/13,AAA,210 2014/10/14,AAA,545 2014/10/15,AAA,90 2014/10/16,CCC,90
これを Spark DataFrame に読み込みます。
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType schema = StructType([ StructField('sales_date', DateType()), StructField('jan_code', StringType()), StructField('sales_cnt', IntegerType()) ]) df = spark.read.csv('<path-to-csv>', schema=schema, header=True, dateFormat='yyyy/MM/dd') df.show() # +----------+--------+---------+ # |sales_date|jan_code|sales_cnt| # +----------+--------+---------+ # |2014-10-06| AAA| 100| # |2014-10-07| AAA| 200| # |2014-10-08| BBB| 100| # |2014-10-09| BBB| 150| # |2014-10-10| BBB| 189| # |2014-10-11| CCC| 120| # |2014-10-12| CCC| 111| # |2014-10-13| AAA| 210| # |2014-10-14| AAA| 545| # |2014-10-15| AAA| 90| # |2014-10-16| CCC| 90| # +----------+--------+---------+
元の SQL では ROW_NUMBER
ウィンドウ関数を使って単純ソートした場合の連続値と jan_code で区切りつつソートした場合の連続値を割り振っていますが、PySpark (Spark SQL) でも pyspark.sql.functions.row_number という同じ関数があります。
from pyspark.sql import Window from pyspark.sql.functions import row_number # SQLでの ROW_NUMBER() OVER(ORDER BY SALES_DATE) に相当 df = df.withColumn('simple_sq', row_number().over(Window.orderBy('sales_date'))) # SQLでの ROW_NUMBER() OVER(PARTITION BY JAN_CODE ORDER BY SALES_DATE) に相当 df = df.withColumn('part_jan_sq', row_number().over(Window.partitionBy('jan_code').orderBy('sales_date')))
パーティションとソート順は pyspark.sql.Window クラスのファクトリメソッドを使って生成する pyspark.sql.WindowSpec
オブジェクトとして渡します。
あとは distance を計算すれば集約カラムが作られますね。
from pyspark.sql.functions import col df = df.withColumn('distance', col('simple_sq') - col('part_jan_sq')) df.orderBy('sales_date').show() # +----------+--------+---------+---------+-----------+--------+ # |sales_date|jan_code|sales_cnt|simple_sq|part_jan_sq|distance| # +----------+--------+---------+---------+-----------+--------+ # |2014-10-06| AAA| 100| 1| 1| 0| # |2014-10-07| AAA| 200| 2| 2| 0| # |2014-10-08| BBB| 100| 3| 1| 2| # |2014-10-09| BBB| 150| 4| 2| 2| # |2014-10-10| BBB| 189| 5| 3| 2| # |2014-10-11| CCC| 120| 6| 1| 5| # |2014-10-12| CCC| 111| 7| 2| 5| # |2014-10-13| AAA| 210| 8| 3| 5| # |2014-10-14| AAA| 545| 9| 4| 5| # |2014-10-15| AAA| 90| 10| 5| 5| # |2014-10-16| CCC| 90| 11| 3| 8| # +----------+--------+---------+---------+-----------+--------+
集約のためのキーができたので、集約を行っておしまい。
grouped_df = df.groupBy(['jan_code', 'distance']) \ .agg(min('sales_date').alias('sales_date_first'), \ max('sales_date').alias('sales_date_last'), \ sum('sales_cnt').alias('cnt_sum')) grouped_df.orderBy('sales_date_first').show() # +--------+--------+----------------+---------------+-------+ # |jan_code|distance|sales_date_first|sales_date_last|cnt_sum| # +--------+--------+----------------+---------------+-------+ # | AAA| 0| 2014-10-06| 2014-10-07| 300| # | BBB| 2| 2014-10-08| 2014-10-10| 439| # | CCC| 5| 2014-10-11| 2014-10-12| 231| # | AAA| 5| 2014-10-13| 2014-10-15| 845| # | CCC| 8| 2014-10-16| 2014-10-16| 90| # +--------+--------+----------------+---------------+-------+