Compare case class fields with sub-fields of another case class in Scala

Submitted by 余生长醉 on 2021-01-27 05:43:11

Question


I have the following 3 case classes:

case class Profile(name: String,
                   age: Int,
                   bankInfoData: BankInfoData,
                   userUpdatedFields: Option[UserUpdatedFields])

case class BankInfoData(accountNumber: Int,
                        bankAddress: String,
                        bankNumber: Int,
                        contactPerson: String,
                        phoneNumber: Int,
                        accountType: AccountType)

case class UserUpdatedFields(contactPerson: String,
                             phoneNumber: Int,
                             accountType: AccountType)

This is just an enum (using enumeratum), but I added it anyway:

import enumeratum._

sealed trait AccountType extends EnumEntry

object AccountType extends Enum[AccountType] {
  val values: IndexedSeq[AccountType] = findValues

  case object Personal extends AccountType

  case object Business extends AccountType

}

My task is to write a function that takes a Profile and compares UserUpdatedFields (all of its fields) with SOME of the fields in BankInfoData. The function should find which fields were updated.

So I wrote this function:

def findDiff(profile: Profile): Seq[String] = {
  var listOfFieldsThatChanged: List[String] = List.empty
  if (profile.bankInfoData.contactPerson != profile.userUpdatedFields.get.contactPerson){
    listOfFieldsThatChanged = listOfFieldsThatChanged :+ "contactPerson"
  }
  if (profile.bankInfoData.phoneNumber != profile.userUpdatedFields.get.phoneNumber) {
    listOfFieldsThatChanged = listOfFieldsThatChanged :+ "phoneNumber"
  }
  if (profile.bankInfoData.accountType != profile.userUpdatedFields.get.accountType) {
    listOfFieldsThatChanged = listOfFieldsThatChanged :+ "accountType"
  }
  listOfFieldsThatChanged
}

val profile =
  Profile(
    "nir",
    34,
    BankInfoData(1, "somewhere", 2, "john", 123, AccountType.Personal),
    Some(UserUpdatedFields("lee", 321, AccountType.Personal))
  )

findDiff(profile) // List(contactPerson, phoneNumber)

It works, but I'd like something cleaner. Any suggestions?


Answer 1:


A simple improvement would be to introduce a trait that both case classes extend:

trait Fields {
  val contactPerson: String
  val phoneNumber: Int
  val accountType: AccountType

  def findDiff(that: Fields): Seq[String] = Seq(
    Some(contactPerson).filter(_ != that.contactPerson).map(_ => "contactPerson"),
    Some(phoneNumber).filter(_ != that.phoneNumber).map(_ => "phoneNumber"),
    Some(accountType).filter(_ != that.accountType).map(_ => "accountType")
  ).flatten
}

case class BankInfoData(accountNumber: Int,
                        bankAddress: String,
                        bankNumber: Int,
                        contactPerson: String,
                        phoneNumber: Int,
                        accountType: AccountType) extends Fields // must be AccountType to satisfy the trait

case class UserUpdatedFields(contactPerson: String,
                             phoneNumber: Int,
                             accountType: AccountType) extends Fields

so that you can call

BankInfoData(...).findDiff(UserUpdatedFields(...))
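A self-contained sketch of this trait approach, applied to the question's sample data, might look like the following. It is written as a script (for example a scala-cli `.sc` file). To keep it free of dependencies, AccountType is modeled as a plain sealed trait instead of the enumeratum enum (an assumption), and the values `bank` and `updated` are illustrative.

```scala
// Plain sealed trait standing in for the enumeratum enum (assumption, to avoid a dependency)
sealed trait AccountType
object AccountType {
  case object Personal extends AccountType
  case object Business extends AccountType
}

trait Fields {
  val contactPerson: String
  val phoneNumber: Int
  val accountType: AccountType

  // Name of every shared field whose value differs from the one in `that`
  def findDiff(that: Fields): Seq[String] = Seq(
    Some(contactPerson).filter(_ != that.contactPerson).map(_ => "contactPerson"),
    Some(phoneNumber).filter(_ != that.phoneNumber).map(_ => "phoneNumber"),
    Some(accountType).filter(_ != that.accountType).map(_ => "accountType")
  ).flatten
}

// Case class parameters are public vals, so they implement the trait's abstract vals
case class BankInfoData(accountNumber: Int, bankAddress: String, bankNumber: Int,
                        contactPerson: String, phoneNumber: Int, accountType: AccountType) extends Fields

case class UserUpdatedFields(contactPerson: String, phoneNumber: Int, accountType: AccountType) extends Fields

val bank = BankInfoData(1, "somewhere", 2, "john", 123, AccountType.Personal)
val updated = UserUpdatedFields("lee", 321, AccountType.Personal)

println(bank.findDiff(updated)) // List(contactPerson, phoneNumber)
```

The field names are still written out by hand here; the trait only guarantees that both classes expose the compared fields with the same types.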

If you want to improve this further and avoid naming all the fields multiple times, you could use shapeless to do it at compile time, or use reflection to do it at runtime.
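As a rough illustration of the runtime option, here is a sketch using plain Java reflection (`getDeclaredFields`, `setAccessible`, `get`). It assumes each case class parameter is stored in a JVM field with the same name, which is how the Scala compiler normally encodes case classes but is not a language guarantee. The helper names `fieldValues` and `findDiff(bank, updated)` are hypothetical, and AccountType is a plain sealed trait rather than the enumeratum enum. It is written as a script (for example a scala-cli `.sc` file).

```scala
import java.lang.reflect.Modifier

sealed trait AccountType
object AccountType {
  case object Personal extends AccountType
  case object Business extends AccountType
}

case class BankInfoData(accountNumber: Int, bankAddress: String, bankNumber: Int,
                        contactPerson: String, phoneNumber: Int, accountType: AccountType)
case class UserUpdatedFields(contactPerson: String, phoneNumber: Int, accountType: AccountType)

// Hypothetical helper: read an object's instance fields through Java reflection.
// Assumes case class parameters are backed by JVM fields of the same name.
def fieldValues(obj: AnyRef): Map[String, Any] =
  obj.getClass.getDeclaredFields.iterator
    .filterNot(f => Modifier.isStatic(f.getModifiers)) // skip static fields
    .filterNot(_.getName.contains("$"))                // skip compiler-generated fields such as $outer
    .map { f =>
      f.setAccessible(true) // case class fields are private on the JVM
      f.getName -> f.get(obj)
    }
    .toMap

// Names of the UserUpdatedFields fields whose value differs in BankInfoData.
// Sorted, because getDeclaredFields does not promise any particular order.
def findDiff(bank: BankInfoData, updated: UserUpdatedFields): List[String] = {
  val bankFields = fieldValues(bank)
  fieldValues(updated).iterator
    .collect { case (name, value) if bankFields.get(name).exists(_ != value) => name }
    .toList
    .sorted
}

val bank = BankInfoData(1, "somewhere", 2, "john", 123, AccountType.Personal)
val updated = UserUpdatedFields("lee", 321, AccountType.Personal)

println(findDiff(bank, updated)) // List(contactPerson, phoneNumber)
```

Reflection avoids listing field names, at the cost of losing compile-time checking: a renamed field silently stops being compared.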




Answer 2:


Every case class extends the Product trait, so we can use it to convert case classes into sets of (field, value) pairs and then use set operations to find the difference. For example:

def findDiff(profile: Profile): Seq[String] = {
  val userUpdatedFields = profile.userUpdatedFields.get
  val bankInfoData = profile.bankInfoData

  // Pair each field name with its value (productElementNames requires Scala 2.13+)
  val updatedFieldsMap = userUpdatedFields.productElementNames.zip(userUpdatedFields.productIterator).toMap
  val bankInfoDataMap = bankInfoData.productElementNames.zip(bankInfoData.productIterator).toMap
  // Keep only the BankInfoData fields that UserUpdatedFields also has
  val bankInfoDataSubsetMap = bankInfoDataMap.view.filterKeys(updatedFieldsMap.contains)
  // (field, value) pairs that are not shared are exactly the updated fields
  (bankInfoDataSubsetMap.toSet diff updatedFieldsMap.toSet).toList.map { case (field, _) => field }
}

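Now findDiff(profile) should return the two changed field names, contactPerson and phoneNumber, for example List(phoneNumber, contactPerson); the order is not guaranteed because the result comes from a set difference. Note that we use productElementNames, added in Scala 2.13, to get the field names, which we then zip with the corresponding values: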

userUpdatedFields.productElementNames.zip(userUpdatedFields.productIterator)

We also rely on filterKeys (called on .view, since Map.filterKeys is deprecated in Scala 2.13) and on diff.
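Because diff works on sets, the order of the returned names is not fixed, and `.get` throws when userUpdatedFields is None. A sketch of a variant (an alternative, not part of the original answer) that walks UserUpdatedFields in declaration order and returns an empty list for None might look like this. It assumes Scala 2.13+, is written as a script (for example a scala-cli `.sc` file), and models AccountType as a plain sealed trait instead of the enumeratum enum.

```scala
sealed trait AccountType
object AccountType {
  case object Personal extends AccountType
  case object Business extends AccountType
}

case class BankInfoData(accountNumber: Int, bankAddress: String, bankNumber: Int,
                        contactPerson: String, phoneNumber: Int, accountType: AccountType)
case class UserUpdatedFields(contactPerson: String, phoneNumber: Int, accountType: AccountType)
case class Profile(name: String, age: Int, bankInfoData: BankInfoData,
                   userUpdatedFields: Option[UserUpdatedFields])

def findDiff(profile: Profile): List[String] = profile.userUpdatedFields match {
  case Some(updated) =>
    val bank = profile.bankInfoData
    // field name -> value for every BankInfoData field (productElementNames needs Scala 2.13+)
    val bankMap = bank.productElementNames.zip(bank.productIterator).toMap
    // Walk UserUpdatedFields in declaration order and keep the names whose value differs
    updated.productElementNames.zip(updated.productIterator)
      .collect { case (name, value) if bankMap.get(name).exists(_ != value) => name }
      .toList
  case None => Nil // no user updates, so no changed fields
}

val profile = Profile("nir", 34,
  BankInfoData(1, "somewhere", 2, "john", 123, AccountType.Personal),
  Some(UserUpdatedFields("lee", 321, AccountType.Personal)))

println(findDiff(profile)) // List(contactPerson, phoneNumber)
```

Iterating over one side instead of diffing two sets keeps the output deterministic, which makes it easier to assert on in tests.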




Answer 3:


This would be an easy task if there were a simple way to convert a case class to a map. Unfortunately, case classes don't offer that functionality out of the box in Scala 2.12 (as Mario mentioned, it is easy to achieve in Scala 2.13).
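For comparison, in Scala 2.13 the case-class-to-map conversion can be written with the standard library alone, because Product gained productElementNames. A minimal sketch, written as a script (for example a scala-cli `.sc` file); the helper name `toFieldMap` is hypothetical and AccountType is a plain sealed trait rather than the enumeratum enum:

```scala
// Hypothetical helper: field name -> value for any case class.
// productElementNames was added to Product in Scala 2.13.
def toFieldMap(p: Product): Map[String, Any] =
  p.productElementNames.zip(p.productIterator).toMap

sealed trait AccountType
object AccountType {
  case object Personal extends AccountType
  case object Business extends AccountType
}

case class UserUpdatedFields(contactPerson: String, phoneNumber: Int, accountType: AccountType)

println(toFieldMap(UserUpdatedFields("lee", 321, AccountType.Personal)))
// Map(contactPerson -> lee, phoneNumber -> 321, accountType -> Personal)
```

The values are typed as Any, so this trades the static types that the shapeless approach keeps for not needing an extra library.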

There's a library called shapeless that offers some generic programming utilities. For example, we could write an extension method toMap using shapeless records (LabelledGeneric) and the ToMap record operation:

object Mappable {
  implicit class RichCaseClass[X](val x: X) extends AnyVal {
    import shapeless._
    import ops.record._

    def toMap[L <: HList](
        implicit gen: LabelledGeneric.Aux[X, L],
        toMap: ToMap[L]
    ): Map[String, Any] =
      toMap(gen.to(x)).map{
        case (k: Symbol, v) => k.name -> v
      }
  }
}

Then we can use it in findDiff:

def findDiff(profile: Profile): Seq[String] = {
  import Mappable._

  profile match {
    case Profile(_, _, bankInfo, Some(userUpdatedFields)) =>
      val bankInfoMap = bankInfo.toMap
      userUpdatedFields.toMap.toList.flatMap{
        case (k, v) if bankInfoMap.get(k).exists(_ != v) => Some(k)
        case _ => None
      }
    case _ => Seq()
  }
}


Source: https://stackoverflow.com/questions/56099428/compare-case-class-fields-with-sub-fields-of-another-case-class-in-scala
